diff --git a/poetry.lock b/poetry.lock index 6c475da8..29ebefd8 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.1.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand. [[package]] name = "aiohappyeyeballs" @@ -1227,6 +1227,23 @@ files = [ {file = "pyyaml-6.0.2.tar.gz", hash = "sha256:d584d9ec91ad65861cc08d42e834324ef890a082e591037abe114850ff7bbc3e"}, ] +[[package]] +name = "resolvelib" +version = "1.2.1" +description = "Resolve abstract dependencies into concrete ones" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "resolvelib-1.2.1-py3-none-any.whl", hash = "sha256:fb06b66c8da04172d9e72a21d7d06186d8919e32ae5ab5cdf5b9d920be805ac2"}, + {file = "resolvelib-1.2.1.tar.gz", hash = "sha256:7d08a2022f6e16ce405d60b68c390f054efcfd0477d4b9bd019cc941c28fad1c"}, +] + +[package.extras] +lint = ["mypy", "ruff", "types-requests"] +release = ["build", "towncrier", "twine"] +test = ["packaging", "pytest"] + [[package]] name = "ruff" version = "0.11.12" @@ -1495,4 +1512,4 @@ propcache = ">=0.2.1" [metadata] lock-version = "2.1" python-versions = ">=3.12,<4.0" -content-hash = "517b3cfcd0189b121dee6198f71b2f23a293c49cca6108cb79cbbe0433d9fc50" +content-hash = "ae3d3a36f3e5e2b55b49c4b33125b5fa7e38100706fc1682bde277e63afccdcc" diff --git a/pyproject.toml b/pyproject.toml index 63e37224..8869b45f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ name = "questionpy-server" description = "QuestionPy application server" license = { file = "LICENSE.md" } urls = { homepage = "https://questionpy.org" } -version = "0.8.0" +version = "0.9.0" authors = [ { name = "TU Berlin innoCampus" }, { email = "info@isis.tu-berlin.de" } @@ -21,7 +21,8 @@ dependencies = [ "semver >=3.0.4, <4.0.0", "psutil >=7.0.0, <8.0.0", "jinja2 >=3.1.6, <4.0.0", - "pyyaml >=6.0.2, <7.0.0" + "pyyaml >=6.0.2, <7.0.0", + "resolvelib >=1.2.1, <2.0.0" ] [tool.poetry.group.dev.dependencies] diff --git a/questionpy_common/__init__.py b/questionpy_common/__init__.py index ee93b02c..13a7f032 100644 --- a/questionpy_common/__init__.py +++ b/questionpy_common/__init__.py @@ -2,7 +2,9 @@ # QuestionPy is free software released under terms of the MIT license. See LICENSE.md. # (c) Technische Universität Berlin, innoCampus from abc import ABC, abstractmethod +from typing import Any, NamedTuple, Self +from pydantic import GetCoreSchemaHandler from pydantic_core import CoreSchema, core_schema @@ -20,3 +22,32 @@ def __get_pydantic_core_schema__(cls, *_: object) -> CoreSchema: TranslatableString.register(str) + + +class PackageNamespaceAndShortName(NamedTuple): + """Tuple of namespace and short name, identifying any version of a specific package.""" + + namespace: str + short_name: str + + def __str__(self) -> str: + return f"@{self.namespace}/{self.short_name}" + + @classmethod + def from_string(cls, value: str) -> Self: + """Parse an NSSN in the same format as produced by `__str__`.""" + value = value.strip() + if not value.startswith("@") or value.count("/") != 1: + msg = f"Invalid package identifier (NSSN): '{value}'" + raise ValueError(msg) + + ns, sn = value.removeprefix("@").split("/", maxsplit=1) + return cls(ns, sn) + + @classmethod + def __get_pydantic_core_schema__(cls, source_type: Any, handle: GetCoreSchemaHandler) -> CoreSchema: + return core_schema.no_info_before_validator_function( + lambda obj: cls.from_string(obj) if isinstance(obj, str) else obj, + handle(source_type), + serialization=core_schema.to_string_ser_schema(), + ) diff --git a/questionpy_common/constants.py b/questionpy_common/constants.py index cf741e0d..929dbf10 100644 --- a/questionpy_common/constants.py +++ b/questionpy_common/constants.py @@ -27,4 +27,19 @@ r"^([a-zA-Z_][a-zA-Z0-9_]*|\.\.)(\[([a-zA-Z_][a-zA-Z0-9_]*|\.\.)?])*$" ) +# Regular expressions. + ENVIRONMENT_VARIABLE_REGEX: Final[str] = r"[a-zA-Z_][a-zA-Z0-9_]*" + +RE_SEMVER = ( + r"^(0|[1-9]\d*)\.(0|[1-9]\d*)\.(0|[1-9]\d*)(?:-((?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*)" + r"(?:\.(?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*))*))?(?:\+([0-9a-zA-Z-]+(?:\.[0-9a-zA-Z-]+)*))?$" +) + +RE_API = r"^(0|[1-9]\d*)\.(0|[1-9]\d*)$" + +# The SemVer and Api version patterns are used on pydantic fields, which uses Rust regexes, so re.compiling them makes +# no sense. We match RE_VALID_CHARS_NAME in Python though, so here it does. +RE_VALID_CHARS_NAME = re.compile(r"^[a-z\d_]+$") + +NAME_MAX_LENGTH = 127 diff --git a/questionpy_common/dependencies.py b/questionpy_common/dependencies.py new file mode 100644 index 00000000..6a8eabd8 --- /dev/null +++ b/questionpy_common/dependencies.py @@ -0,0 +1,58 @@ +from dataclasses import dataclass, field +from typing import Annotated + +from pydantic import Field + +from questionpy_common import PackageNamespaceAndShortName +from questionpy_common.constants import RE_SEMVER +from questionpy_common.manifest import DistDependencies +from questionpy_common.package_location import PackageLocation + + +@dataclass(frozen=True) +class StaticDependencySolution: + """Indicates that a package in the tree provides a static dependency that is to be used. + + Usually, this is a solution for the static dependency itself, but if there is also a dynamic dependency in the tree + for that NSSN _and_ that dynamic dependency allows the static version, a `StaticDependencySolution` might be used to + also solve a dynamic dependency. + + If multiple packages provide the same version of a static dependency, any of them may be used as the solution. + Static dependencies on different versions of the same NSSN will always lead to a `DependencyConflictError`. + """ + + nssn: PackageNamespaceAndShortName + + owner: PackageNamespaceAndShortName + """The package that includes this static dependency.""" + + hash: str + # semver.Version is avoided to allow solutions to be passed to the package. + version: Annotated[str, Field(pattern=RE_SEMVER)] + + dependencies: DistDependencies = field(compare=False) + """Transitive dependencies of this dependency.""" + + def __str__(self) -> str: + return f"{self.hash} ({self.version}, statically packaged in '{self.owner}')" + + +@dataclass(frozen=True) +class DynamicDependencySolution: + """Indicates that the given version should be used to supply all usages of the NSSN.""" + + nssn: PackageNamespaceAndShortName + + hash: str + # semver.Version is avoided to allow solutions to be passed to the package. + version: Annotated[str, Field(pattern=RE_SEMVER)] + dependencies: DistDependencies = field(compare=False) + """Transitive dependencies of this dependency.""" + + def __str__(self) -> str: + return f"{self.hash} ({self.version}, dynamic)" + + +type DependencySolution = StaticDependencySolution | DynamicDependencySolution + +type SolutionAndLocation = tuple[DynamicDependencySolution, PackageLocation] | tuple[StaticDependencySolution, None] diff --git a/questionpy_common/environment.py b/questionpy_common/environment.py index 04fea5cb..27eecb16 100644 --- a/questionpy_common/environment.py +++ b/questionpy_common/environment.py @@ -7,10 +7,11 @@ from enum import Enum from functools import total_ordering from importlib.resources.abc import Traversable -from typing import NamedTuple, Protocol +from typing import Protocol from pydantic import BaseModel, JsonValue +from questionpy_common import PackageNamespaceAndShortName from questionpy_common.api.package import QPyPackageInterface from questionpy_common.manifest import Bcp47LanguageTag, Manifest @@ -21,7 +22,6 @@ "OnRequestCallback", "Package", "PackageInitFunction", - "PackageNamespaceAndShortName", "PackageNotInitializedError", "PackageNotLoadedError", "PackagePermissions", @@ -79,16 +79,6 @@ def __lt__(self, other: object) -> bool: return NotImplemented -class PackageNamespaceAndShortName(NamedTuple): - """Tuple of namespace and short name, identifying any version of a specific package.""" - - namespace: str - short_name: str - - def __str__(self) -> str: - return f"@{self.namespace}/{self.short_name}" - - class Package(Protocol): @property def manifest(self) -> Manifest: ... diff --git a/questionpy_common/manifest.py b/questionpy_common/manifest.py index 39370c6c..1f4c52bc 100644 --- a/questionpy_common/manifest.py +++ b/questionpy_common/manifest.py @@ -2,15 +2,31 @@ # QuestionPy is free software released under terms of the MIT license. See LICENSE.md. # (c) Technische Universität Berlin, innoCampus -import re +from abc import ABC from enum import StrEnum from keyword import iskeyword, issoftkeyword -from typing import Annotated, NewType - -from pydantic import BaseModel, ByteSize, PositiveInt, StringConstraints, conset, field_validator +from typing import Annotated, Literal, NewType + +from pydantic import ( + AfterValidator, + BaseModel, + ByteSize, + PositiveInt, + StringConstraints, + conset, + field_validator, +) from pydantic.fields import Field -from questionpy_common.constants import ENVIRONMENT_VARIABLE_REGEX +from questionpy_common import PackageNamespaceAndShortName +from questionpy_common.constants import ( + ENVIRONMENT_VARIABLE_REGEX, + NAME_MAX_LENGTH, + RE_API, + RE_SEMVER, + RE_VALID_CHARS_NAME, +) +from questionpy_common.version_specifiers import QPyDependencyVersionSpecifier class PackageType(StrEnum): @@ -19,22 +35,9 @@ class PackageType(StrEnum): QUESTION = "QUESTION" -# Defaults. DEFAULT_NAMESPACE = "local" DEFAULT_PACKAGETYPE = PackageType.QUESTIONTYPE -# Regular expressions. -RE_SEMVER = ( - r"^(0|[1-9]\d*)\.(0|[1-9]\d*)\.(0|[1-9]\d*)(?:-((?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*)" - r"(?:\.(?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*))*))?(?:\+([0-9a-zA-Z-]+(?:\.[0-9a-zA-Z-]+)*))?$" -) -RE_API = r"^(0|[1-9]\d*)\.(0|[1-9]\d*)$" -# The SemVer and Api version patterns are used on pydantic fields, which uses Rust regexes, so re.compiling them makes -# no sense. We match RE_VALID_CHARS_NAME in Python though, so here it does. -RE_VALID_CHARS_NAME = re.compile(r"^[a-z\d_]+$") - -NAME_MAX_LENGTH = 127 - # Validators. def ensure_is_valid_name(name: str) -> str: @@ -139,6 +142,10 @@ def ensure_contains_english_translation( def identifier(self) -> str: return f"@{self.namespace}/{self.short_name}" + @property + def nssn(self) -> PackageNamespaceAndShortName: + return PackageNamespaceAndShortName(self.namespace, self.short_name) + class PackageFile(BaseModel): """Represents a static file included in a built package.""" @@ -148,13 +155,46 @@ class PackageFile(BaseModel): class DistStaticQPyDependency(BaseModel): - dir_name: str - """Name (without `dist/dependencies/qpy/`) of the directory the dependency package contents reside in.""" + namespace: Annotated[str, AfterValidator(ensure_is_valid_name)] + short_name: Annotated[str, AfterValidator(ensure_is_valid_name)] + version: Annotated[str, Field(pattern=RE_SEMVER)] + + dependencies: "DistDependencies" + """Transitive dependencies of this dependency.""" + hash: str - """Hash of the ZIP package whose contents lie in `dir_name`.""" + """Hash of the ZIP package whose contents are included in this package.""" + + @property + def nssn(self) -> PackageNamespaceAndShortName: + return PackageNamespaceAndShortName(self.namespace, self.short_name) + + +type DependencyLockStrategy = Literal["required", "preferred-no-downgrade", "preferred-allow-downgrade"] + + +class LockedDependencyInfo(BaseModel): + strategy: DependencyLockStrategy + locked_version: Annotated[str, Field(pattern=RE_SEMVER)] + locked_hash: str + + +class AbstractDynamicQPyDependency(BaseModel, ABC): + namespace: Annotated[str, AfterValidator(ensure_is_valid_name)] + short_name: Annotated[str, AfterValidator(ensure_is_valid_name)] + version: QPyDependencyVersionSpecifier | None = None + include_prereleases: bool = False + + @property + def nssn(self) -> PackageNamespaceAndShortName: + return PackageNamespaceAndShortName(self.namespace, self.short_name) + + +class DistDynamicQPyDependency(AbstractDynamicQPyDependency): + locked: LockedDependencyInfo | None = None -type DistQPyDependency = DistStaticQPyDependency +type DistQPyDependency = DistStaticQPyDependency | DistDynamicQPyDependency class DistDependencies(BaseModel): diff --git a/questionpy_server/worker/runtime/package_location.py b/questionpy_common/package_location.py similarity index 100% rename from questionpy_server/worker/runtime/package_location.py rename to questionpy_common/package_location.py diff --git a/questionpy_common/version_specifiers.py b/questionpy_common/version_specifiers.py new file mode 100644 index 00000000..039f0aeb --- /dev/null +++ b/questionpy_common/version_specifiers.py @@ -0,0 +1,136 @@ +import re +from collections.abc import Iterable +from dataclasses import dataclass +from typing import Any, Literal, Protocol, Self + +from pydantic import GetCoreSchemaHandler +from pydantic_core import CoreSchema, core_schema + +from questionpy_common.constants import RE_SEMVER + +type _Operator = Literal["==", "!=", ">=", "<=", ">", "<", "^="] +_OPERATORS: tuple[_Operator, ...] = ("==", "!=", ">=", "<=", ">", "<", "^=") + +_SEMVER_PATTERN = re.compile(RE_SEMVER) + + +class VersionProtocol(Protocol): + """Partial protocol for SemVer version objects. + + We don't want `questionpy_common` to depend on the `semver` package, so we define this protocol instead of using + `semver.Version` directly. + """ + + @staticmethod + def parse(string: str) -> "VersionProtocol": ... + + def is_compatible(self, other: Self) -> bool: ... + + def __gt__(self, other: Self) -> bool: ... + + def __ge__(self, other: Self) -> bool: ... + + def __lt__(self, other: Self) -> bool: ... + + def __le__(self, other: Self) -> bool: ... + + +@dataclass(frozen=True) +class QPyDependencyVersionSpecifier: + """One or more clauses restricting allowed versions for a QPy package dependency.""" + + @dataclass(frozen=True) + class Clause: + """A single comparison clause such as `>= 1.2.2`.""" + + operator: _Operator + operand: str + + def __post_init__(self) -> None: + if self.operator == "^=" and "-" in self.operand: + # Prereleases are never compatible with different prereleases, so this would make little sense. + msg = "The '^=' operator cannot be used with prereleases." + raise ValueError(msg) + + def allows(self, version: VersionProtocol) -> bool: + """Check if this clause is fulfilled by the given version.""" + # Note: The semver package we use does already implement a `match` method, but we would like to validate + # each clause early, before the matching needs to be done. + parsed_operand = type(version).parse(self.operand) + match self.operator: + case "<": + return version < parsed_operand + case "<=": + return version <= parsed_operand + case "==": + return version == parsed_operand + case ">=": + return version >= parsed_operand + case ">": + return version > parsed_operand + case "^=": + return parsed_operand.is_compatible(version) + case _: + # Shouldn't be reachable. + msg = f"Invalid operator: {self.operator}" + raise ValueError(msg) + + @classmethod + def from_string(cls, string: str) -> Self: + string = string.strip() + + operator = next(filter(string.startswith, _OPERATORS), None) + if operator: + version_string = string.removeprefix(operator).lstrip() + if not _SEMVER_PATTERN.match(version_string): + msg = f"Comparison version '{version_string}' of clause '{string}' does not conform to SemVer." + raise ValueError(msg) + + operand = version_string + else: + # No operator. Check if string is a version, since we allow "==" to be omitted. + if not _SEMVER_PATTERN.match(string): + msg = ( + f"Version specifier clause '{string}' does not start with a valid operator and isn't a " + f"version itself. Valid operators are {', '.join(_OPERATORS)}." + ) + raise ValueError(msg) + + operator = "==" + operand = string + + return cls(operator, operand) + + def __str__(self) -> str: + return f"{self.operator} {self.operand}" + + # Dict because we want to preserve order (for readability) but not compare order or allow dupes. + _clauses: dict[Clause, None] + + def __init__(self, clauses: Iterable[Clause]) -> None: + super().__setattr__("_clauses", dict.fromkeys(clauses)) + + @property + def clauses(self) -> tuple[Clause, ...]: + return tuple(self._clauses) + + def __str__(self) -> str: + return ", ".join(map(str, self._clauses)) + + @classmethod + def from_string(cls, string: str) -> Self: + return cls( + tuple(cls.Clause.from_string(clause) for clause in string.split(",") if string and not string.isspace()) + ) + + def allows(self, version: VersionProtocol) -> bool: + """Checks if _all_ clauses allow the given version.""" + return all(clause.allows(version) for clause in self._clauses) + + @classmethod + def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> CoreSchema: + return core_schema.json_or_python_schema( + core_schema.no_info_after_validator_function(cls.from_string, handler(str)), + core_schema.is_instance_schema(cls), + serialization=core_schema.to_string_ser_schema(), + ) diff --git a/questionpy_server/collector/_package_collection.py b/questionpy_server/collector/_package_collection.py index 08483f3c..73cc5dba 100644 --- a/questionpy_server/collector/_package_collection.py +++ b/questionpy_server/collector/_package_collection.py @@ -7,16 +7,15 @@ from pathlib import Path from typing import TYPE_CHECKING +import semver from pydantic import HttpUrl -from questionpy_server import WorkerPool from questionpy_server.cache import LRUCache from questionpy_server.collector.indexer import Indexer from questionpy_server.collector.lms_collector import LMSCollector from questionpy_server.collector.local_collector import LocalCollector from questionpy_server.collector.repo_collector import RepoCollector from questionpy_server.models import PackageVersionsInfo -from questionpy_server.utils.manifest import SemVer if TYPE_CHECKING: from questionpy_server.collector.abc import BaseCollector @@ -33,9 +32,8 @@ def __init__( repos: dict[HttpUrl, timedelta], repo_index_cache: LRUCache, package_cache: LRUCache, - worker_pool: WorkerPool, ): - self._indexer = Indexer(worker_pool) + self._indexer = Indexer() self._collectors: list[BaseCollector] = [] if local_dir: @@ -89,7 +87,7 @@ def get(self, package_hash: str) -> "Package | None": """ return self._indexer.get_by_hash(package_hash) - def get_by_identifier(self, identifier: str) -> dict[SemVer, "Package"]: + def get_by_identifier(self, identifier: str) -> dict[semver.Version, "Package"]: """Returns a dict of packages with the given identifier and available versions. Args: @@ -100,7 +98,7 @@ def get_by_identifier(self, identifier: str) -> dict[SemVer, "Package"]: """ return self._indexer.get_by_identifier(identifier) - def get_by_identifier_and_version(self, identifier: str, version: SemVer) -> "Package | None": + def get_by_identifier_and_version(self, identifier: str, version: semver.Version) -> "Package | None": """Returns a package with the given identifier and version. Args: diff --git a/questionpy_server/collector/indexer.py b/questionpy_server/collector/indexer.py index 49803b62..41978121 100644 --- a/questionpy_server/collector/indexer.py +++ b/questionpy_server/collector/indexer.py @@ -7,13 +7,15 @@ from pathlib import Path from typing import overload -from questionpy_server import WorkerPool +import semver +from semver import VersionInfo as _Version + from questionpy_server.collector.abc import BaseCollector from questionpy_server.collector.local_collector import LocalCollector from questionpy_server.collector.repo_collector import RepoCollector from questionpy_server.models import PackageInfo, PackageVersionsInfo, PackageVersionSpecificInfo from questionpy_server.package import Package -from questionpy_server.utils.manifest import ComparableManifest, SemVer, read_manifest +from questionpy_server.utils.manifest import Manifest, read_manifest_from_zip class Indexer: @@ -23,11 +25,9 @@ class Indexer: only indexed by its hash. """ - def __init__(self, worker_pool: WorkerPool): - self._worker_pool = worker_pool - + def __init__(self) -> None: self._index_by_hash: dict[str, Package] = {} - self._index_by_identifier: dict[str, dict[SemVer, Package]] = {} + self._index_by_identifier: dict[str, dict[semver.Version, Package]] = {} """dict[identifier, dict[version, Package]]""" self._package_versions_infos: list[PackageVersionsInfo] | None = None @@ -45,7 +45,7 @@ def get_by_hash(self, package_hash: str) -> Package | None: """ return self._index_by_hash.get(package_hash, None) - def get_by_identifier(self, identifier: str) -> dict[SemVer, Package]: + def get_by_identifier(self, identifier: str) -> dict[semver.Version, Package]: """Returns a dict of packages with the given identifier and available versions. Args: @@ -56,7 +56,7 @@ def get_by_identifier(self, identifier: str) -> dict[SemVer, Package]: """ return self._index_by_identifier.get(identifier, {}).copy() - def get_by_identifier_and_version(self, identifier: str, version: SemVer) -> Package | None: + def get_by_identifier_and_version(self, identifier: str, version: semver.Version) -> Package | None: """Returns the package with the given identifier and version or None if it does not exist. Args: @@ -100,20 +100,20 @@ def get_package_versions_infos(self) -> list[PackageVersionsInfo]: @overload async def register_package( - self, package_hash: str, path_or_manifest: ComparableManifest, source: BaseCollector + self, package_hash: str, path_or_manifest: Manifest, source: BaseCollector ) -> Package: ... @overload async def register_package(self, package_hash: str, path_or_manifest: Path, source: BaseCollector) -> Package: ... async def register_package( - self, package_hash: str, path_or_manifest: Path | ComparableManifest, source: BaseCollector + self, package_hash: str, path_or_manifest: Path | Manifest, source: BaseCollector ) -> Package: """Registers a package in the index. Args: package_hash (str): The hash of the package. - path_or_manifest (Union[Path, ComparableManifest]): The manifest of the package. + path_or_manifest (Union[Path, Manifest]): The manifest of the package. source (BaseCollector): The source of the package. Raises: @@ -130,7 +130,7 @@ async def register_package( # Create new package... if isinstance(path_or_manifest, Path): # ...from path. - manifest = await read_manifest(path_or_manifest) + manifest = await read_manifest_from_zip(path_or_manifest) package = Package(package_hash, manifest, source, path_or_manifest) else: # ...from manifest. @@ -140,7 +140,8 @@ async def register_package( # Check if package should be accessible by identifier and version. if isinstance(source, LocalCollector | RepoCollector): package_versions = self._index_by_identifier.setdefault(package.manifest.identifier, {}) - existing_package = package_versions.get(package.manifest.version, None) + comparable_manifest = package.manifest + existing_package = package_versions.get(_Version.parse(comparable_manifest.version), None) if existing_package and existing_package != package: # Package with the same identifier and version already exists; log a warning. log = logging.getLogger("questionpy-server:indexer") @@ -152,7 +153,8 @@ async def register_package( existing_package.hash, ) else: - package_versions[package.manifest.version] = package + manifest1 = package.manifest + package_versions[_Version.parse(manifest1.version)] = package # Force recalculation of list[PackageVersionsInfo]. self._package_versions_infos = None @@ -182,7 +184,8 @@ async def unregister_package(self, package_hash: str, source: BaseCollector) -> package_versions = self._index_by_identifier.get(package.manifest.identifier, None) if package_versions: # Remove package from index. - package_versions.pop(package.manifest.version, None) + manifest = package.manifest + package_versions.pop(_Version.parse(manifest.version), None) # If there are no more packages with the same identifier, remove the identifier from the index. if not package_versions: self._index_by_identifier.pop(package.manifest.identifier, None) diff --git a/questionpy_server/dependencies/__init__.py b/questionpy_server/dependencies/__init__.py new file mode 100644 index 00000000..63ebc58a --- /dev/null +++ b/questionpy_server/dependencies/__init__.py @@ -0,0 +1,27 @@ +from ._dynamic_resolver_abc import ( + AvailablePackageVersion, + DynamicDependencyResolver, + NoopDependencyResolver, + NoPackageWithHashError, +) +from ._package_collection_adapter import PackageCollectionDependencyResolver +from ._solver import resolve_dependency_tree +from ._solver.errors import ( + DependencyConflictError, + DependencyCycleError, + QPyDependencyError, + TooDeeplyNestedDependencyError, +) + +__all__ = [ + "AvailablePackageVersion", + "DependencyConflictError", + "DependencyCycleError", + "DynamicDependencyResolver", + "NoPackageWithHashError", + "NoopDependencyResolver", + "PackageCollectionDependencyResolver", + "QPyDependencyError", + "TooDeeplyNestedDependencyError", + "resolve_dependency_tree", +] diff --git a/questionpy_server/dependencies/_dynamic_resolver_abc.py b/questionpy_server/dependencies/_dynamic_resolver_abc.py new file mode 100644 index 00000000..fd366af5 --- /dev/null +++ b/questionpy_server/dependencies/_dynamic_resolver_abc.py @@ -0,0 +1,72 @@ +import logging +from abc import ABC, abstractmethod +from collections.abc import Iterable +from dataclasses import dataclass +from typing import final + +from questionpy_common import PackageNamespaceAndShortName +from questionpy_common.package_location import PackageLocation +from questionpy_common.version_specifiers import QPyDependencyVersionSpecifier +from questionpy_server.utils.manifest import Manifest, ParsableSemverVersion + +_log = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class AvailablePackageVersion: + manifest: Manifest + hash: str + version: ParsableSemverVersion + + +class NoPackageWithHashError(Exception): + def __init__(self, hash_: str) -> None: + msg = f"Previously found package with hash '{hash_}' cannot be retrieved." + super().__init__(msg) + + +class DynamicDependencyResolver(ABC): + """Finds packages matching a given dynamic dependency. + + This is implemented using the `PackageCollection` in the server, but we also intend for the solver to be used by the + SDK at build-time, so this is an ABC. + """ + + @abstractmethod + def get_matching_versions( + self, + nssn: PackageNamespaceAndShortName, + version_spec: QPyDependencyVersionSpecifier | None, + *, + include_prereleases: bool, + ) -> Iterable[AvailablePackageVersion]: + """Returns all package versions matching the given restrictions, in any order. + + When this resolver has no packages matching the given restrictions, no error is thrown, just an empty iterable + returned. + """ + + @abstractmethod + async def get_package_location(self, hash_: str) -> PackageLocation: + """Gets a location for a package that was previously returned by `get_matching_versions`. + + Raises: + NoPackageWithHashError + """ + + +@final +class NoopDependencyResolver(DynamicDependencyResolver): + """A dependency resolver that does not provide any dependencies.""" + + def get_matching_versions( + self, + nssn: PackageNamespaceAndShortName, + version_spec: QPyDependencyVersionSpecifier | None, + *, + include_prereleases: bool, + ) -> Iterable[AvailablePackageVersion]: + return () + + async def get_package_location(self, hash_: str) -> PackageLocation: + raise NoPackageWithHashError(hash_) diff --git a/questionpy_server/dependencies/_package_collection_adapter.py b/questionpy_server/dependencies/_package_collection_adapter.py new file mode 100644 index 00000000..ff4b4e1c --- /dev/null +++ b/questionpy_server/dependencies/_package_collection_adapter.py @@ -0,0 +1,42 @@ +from collections.abc import Iterable + +from questionpy_common import PackageNamespaceAndShortName +from questionpy_common.package_location import PackageLocation +from questionpy_common.version_specifiers import QPyDependencyVersionSpecifier +from questionpy_server.collector import PackageCollection + +from ._dynamic_resolver_abc import ( + AvailablePackageVersion, + DynamicDependencyResolver, + NoPackageWithHashError, +) + + +class PackageCollectionDependencyResolver(DynamicDependencyResolver): + """A dependency resolver backed by the QPy server's `PackageCollection`.""" + + def __init__(self, package_collection: PackageCollection) -> None: + self._package_collection = package_collection + + def get_matching_versions( + self, + nssn: PackageNamespaceAndShortName, + version_spec: QPyDependencyVersionSpecifier | None, + *, + include_prereleases: bool, + ) -> Iterable[AvailablePackageVersion]: + versions = self._package_collection.get_by_identifier(str(nssn)) + + return tuple( + AvailablePackageVersion(package.manifest, package.hash, version) + for version, package in versions.items() + if (include_prereleases or version.prerelease is None) + and (version_spec is None or version_spec.allows(version)) + ) + + async def get_package_location(self, hash_: str) -> PackageLocation: + package = self._package_collection.get(hash_) + if not package: + raise NoPackageWithHashError(hash_) + + return await package.get_zip_package_location() diff --git a/questionpy_server/dependencies/_solver/__init__.py b/questionpy_server/dependencies/_solver/__init__.py new file mode 100644 index 00000000..cc705cec --- /dev/null +++ b/questionpy_server/dependencies/_solver/__init__.py @@ -0,0 +1,74 @@ +from collections.abc import Sequence + +import resolvelib +from resolvelib import ResolutionImpossible, ResolutionTooDeep +from semver import Version + +from questionpy_common import PackageNamespaceAndShortName +from questionpy_common.dependencies import DependencySolution +from questionpy_common.manifest import ( + DistDependencies, + DistQPyDependency, + SourceManifest, +) +from questionpy_server.dependencies._dynamic_resolver_abc import DynamicDependencyResolver + +from ._model import RootRequirementAndCandidate +from ._provider import QPyResolvelibProvider +from ._reporter import QPyResolvelibReporter +from .errors import DependencyConflictError, QPyDependencyError + + +def resolve_dependency_tree( + root: SourceManifest, + root_dependencies: Sequence[DistQPyDependency], + dynamic_dep_resolver: DynamicDependencyResolver, +) -> dict[PackageNamespaceAndShortName, DependencySolution]: + """Finds solutions for all packages in the dependency tree of the given root package. + + Raises: + DependencyConflictError: If different requirements for a package conflict. + TooDeeplyNestedDependencyError: If MAX_QPY_DEPENDENCY_LEVELS is exceeded. + DependencyCycleError: If there is a cycle in the dependency tree after resolution. + QPyDependencyError: If the dependency tree cannot be resolved at all. (Which may also be due to some kinds of + cycles.) + """ + provider = QPyResolvelibProvider(dynamic_dep_resolver) + reporter = QPyResolvelibReporter(root.nssn) + resolver = resolvelib.Resolver(provider, reporter) + + root_node = RootRequirementAndCandidate( + root.nssn, Version.parse(root.version), DistDependencies(qpy=list(root_dependencies)) + ) + + try: + result = resolver.resolve((root_node,)) + except Exception as e: + reporter.log_failed_resolution(e) + + if isinstance(e, QPyDependencyError): + raise + + if isinstance(e, ResolutionImpossible): + causes = tuple(e.causes) + if not causes: + msg = "ResolutionImpossible without causes" + raise QPyDependencyError(msg) from e + + nssn = causes[0].requirement.nssn + raise DependencyConflictError(nssn, causes) from e + + if isinstance(e, ResolutionTooDeep): + msg = f"Dependency resolution took too many ({e.round_count}) rounds." + raise QPyDependencyError(msg) from e + + msg = "Unknown resolution error" + raise QPyDependencyError(msg) from e + else: + reporter.log_successful_resolution() + + return { + nssn: candidate + for nssn, candidate in result.mapping.items() + if not isinstance(candidate, RootRequirementAndCandidate) + } diff --git a/questionpy_server/dependencies/_solver/_model.py b/questionpy_server/dependencies/_solver/_model.py new file mode 100644 index 00000000..fc0689cd --- /dev/null +++ b/questionpy_server/dependencies/_solver/_model.py @@ -0,0 +1,55 @@ +from dataclasses import dataclass + +from semver import Version + +from questionpy_common import PackageNamespaceAndShortName +from questionpy_common.dependencies import DynamicDependencySolution, StaticDependencySolution +from questionpy_common.manifest import AbstractDynamicQPyDependency, DistDependencies, DistStaticQPyDependency + + +@dataclass(frozen=True) +class DynamicRequirement: + dep: AbstractDynamicQPyDependency + + @property + def nssn(self) -> PackageNamespaceAndShortName: + return self.dep.nssn + + def __str__(self) -> str: + string = str(self.dep.version) if self.dep.version else "any version" + string += " (including prereleases)" if self.dep.include_prereleases else " (excluding prereleases)" + return string + + +@dataclass(frozen=True) +class StaticRequirement: + owner: PackageNamespaceAndShortName + dep: DistStaticQPyDependency + + @property + def nssn(self) -> PackageNamespaceAndShortName: + return self.dep.nssn + + def __str__(self) -> str: + return f"statically packaged version {self.dep.version} ({self.dep.hash})" + + +@dataclass(frozen=True) +class RootRequirementAndCandidate: + """Represents the root package as both a requirement and a candidate.""" + + nssn: PackageNamespaceAndShortName + version: Version + dependencies: DistDependencies + + def __str__(self) -> str: + return f"root package ({self.nssn}:{self.version})" + + +type Requirement = DynamicRequirement | StaticRequirement | RootRequirementAndCandidate + + +type DynamicCandidate = DynamicDependencySolution +type StaticCandidate = StaticDependencySolution + +type Candidate = DynamicCandidate | StaticCandidate | RootRequirementAndCandidate diff --git a/questionpy_server/dependencies/_solver/_provider.py b/questionpy_server/dependencies/_solver/_provider.py new file mode 100644 index 00000000..43d83ee8 --- /dev/null +++ b/questionpy_server/dependencies/_solver/_provider.py @@ -0,0 +1,248 @@ +from collections.abc import Iterable, Iterator, Mapping, Sequence + +import resolvelib +from resolvelib.structs import Matches, RequirementInformation +from semver import Version + +from questionpy_common import PackageNamespaceAndShortName +from questionpy_common.dependencies import DynamicDependencySolution, StaticDependencySolution +from questionpy_common.manifest import AbstractDynamicQPyDependency, DistDynamicQPyDependency, DistStaticQPyDependency +from questionpy_common.version_specifiers import QPyDependencyVersionSpecifier +from questionpy_server.dependencies._dynamic_resolver_abc import ( + DynamicDependencyResolver, +) + +from ._model import ( + Candidate, + DynamicRequirement, + Requirement, + RootRequirementAndCandidate, + StaticRequirement, +) + + +class _MergedDynamicDep(AbstractDynamicQPyDependency): + pass + + +def _merge_dynamic_deps(dep: AbstractDynamicQPyDependency, *deps: AbstractDynamicQPyDependency) -> _MergedDynamicDep: + clauses = list(dep.version.clauses) if dep.version else [] + include_prereleases = dep.include_prereleases + + for other_dep in deps: + if isinstance(other_dep, AbstractDynamicQPyDependency): + if other_dep.version: + clauses.extend(other_dep.version.clauses) + include_prereleases &= other_dep.include_prereleases + + return _MergedDynamicDep( + namespace=dep.namespace, + short_name=dep.short_name, + version=QPyDependencyVersionSpecifier(clauses) if clauses else None, + include_prereleases=include_prereleases, + ) + + +def _partition_reqs( + reqs: Iterable[Requirement], +) -> tuple[Sequence[DynamicRequirement], Sequence[StaticRequirement], RootRequirementAndCandidate | None]: + dynamic: list[DynamicRequirement] = [] + static: list[StaticRequirement] = [] + root: RootRequirementAndCandidate | None = None + + for req in reqs: + if isinstance(req, DynamicRequirement): + dynamic.append(req) + elif isinstance(req, StaticRequirement): + static.append(req) + else: + root = req + + return dynamic, static, root + + +def _do_dynamic_reqs_allow_candidate(dynamic_reqs: Sequence[DynamicRequirement], cand_version: str | Version) -> bool: + if isinstance(cand_version, str): + cand_version = Version.parse(cand_version) + + for dynamic_req in dynamic_reqs: + allows = (dynamic_req.dep.include_prereleases or cand_version.prerelease is None) and ( + dynamic_req.dep.version is None or dynamic_req.dep.version.allows(cand_version) + ) + if not allows: + return False + + return True + + +def _find_static_matches( + nssn: PackageNamespaceAndShortName, + static_reqs: Sequence[StaticRequirement], + dynamic_reqs: Sequence[DynamicRequirement], +) -> list[Candidate]: + """When one or more static requirements exist for a package, check that they're the same and return solutions.""" + for static_req in static_reqs[1:]: + # We only compare the hash, since future changes in the manifest format might lead to inconsequential + # differences between the 'dependencies' fields. + if static_req.dep.hash != static_reqs[0].dep.hash: + # There are multiple _different_ static versions of the dependency required. + return [] + + # All the static dependencies are equivalent. + + if not _do_dynamic_reqs_allow_candidate(dynamic_reqs, static_reqs[0].dep.version): + # At least one dynamic dependency does not allow the static version. + return [] + + return [ + StaticDependencySolution( + nssn=nssn, + owner=static_req.owner, + hash=static_req.dep.hash, + version=static_req.dep.version, + dependencies=static_req.dep.dependencies, + ) + for static_req in static_reqs + ] + + +def _find_dynamic_matches( + nssn: PackageNamespaceAndShortName, dynamic_reqs: Sequence[DynamicRequirement], resolver: DynamicDependencyResolver +) -> list[Candidate]: + """When only dynamic requirements exist for a package, find all matching available package versions.""" + merged = _merge_dynamic_deps(*(req.dep for req in dynamic_reqs)) + + # TODO: Use locked version if possible. + # We sort from highest (i.e. latest) version to lowest (i.e. oldest), since resolvelib tries candidates in order. + matching_package_versions = sorted( + resolver.get_matching_versions( + nssn=nssn, + version_spec=merged.version, + include_prereleases=merged.include_prereleases, + ), + key=lambda apv: apv.version, + reverse=True, + ) + + return [ + DynamicDependencySolution( + nssn=nssn, + hash=apv.hash, + version=apv.manifest.version, + dependencies=apv.manifest.dependencies, + ) + for apv in matching_package_versions + ] + + +class QPyResolvelibProvider(resolvelib.AbstractProvider[Requirement, Candidate, PackageNamespaceAndShortName]): + def __init__(self, dynamic_resolver: DynamicDependencyResolver) -> None: + self._dynamic_resolver = dynamic_resolver + + def identify(self, requirement_or_candidate: Requirement | Candidate) -> PackageNamespaceAndShortName: + return requirement_or_candidate.nssn + + def get_preference( + self, + identifier: PackageNamespaceAndShortName, + resolutions: Mapping[PackageNamespaceAndShortName, Candidate], + candidates: Mapping[PackageNamespaceAndShortName, Iterator[Candidate]], + information: Mapping[PackageNamespaceAndShortName, Iterator[RequirementInformation[Requirement, Candidate]]], + backtrack_causes: Sequence[RequirementInformation[Requirement, Candidate]], + ) -> tuple[bool, bool, bool, bool, PackageNamespaceAndShortName]: + # This method only serves to optimize the resolution by resolving more restricted packages first. + # In our case, we resolve in the following order: + # - root package + # - static dependencies + # - dynamic dependencies with at least one "==" constraint + # - dynamic dependencies with any constraints + # - dynamic dependencies without constraints + # + # Within those groups, we use alphabetical order, for consistency. + # This strategy is inspired by pip. + + dynamic_reqs, static_reqs, root_req = _partition_reqs( + info.requirement for info in information.get(identifier, ()) + ) + + is_root = root_req is not None + is_static = len(static_reqs) > 0 + + if dynamic_reqs: + merged = _merge_dynamic_deps(*(req.dep for req in dynamic_reqs)) + is_pinned = any(clause.operator == "==" for clause in merged.version.clauses) if merged.version else False + is_restricted = merged.version is not None and len(merged.version.clauses) > 0 + else: + is_pinned = False + is_restricted = False + + return ( + not is_root, + not is_static, + not is_pinned, + not is_restricted, + identifier, + ) + + def find_matches( + self, + identifier: PackageNamespaceAndShortName, + requirements: Mapping[PackageNamespaceAndShortName, Iterator[Requirement]], + incompatibilities: Mapping[PackageNamespaceAndShortName, Iterator[Candidate]], + ) -> Matches[Candidate]: + reqs = tuple(requirements.get(identifier, ())) + incompatible_candidates = tuple(incompatibilities.get(identifier, ())) + + if not reqs: + msg = f"There is no requirement on '{identifier}', why are we resolving it?" + raise RuntimeError(msg) + + dynamic_reqs, static_reqs, root_req = _partition_reqs(reqs) + + if root_req: + if root_req in incompatible_candidates: + # The root requirement has for some reason been marked as incompatible in a previous backtracking round. + return () + if static_reqs: + # There is also a static dependency on the root package, which isn't allowed. + return () + + # If there is a dynamic dependency on the root package, it's always a cycle. + # We could return () in that case, but letting the cycle check later on handle this will lead to a better + # error message than we could generate here. + if not _do_dynamic_reqs_allow_candidate(dynamic_reqs, root_req.version): + # Of course, if the version doesn't match, we still prevent it. + return () + + return (root_req,) + + if static_reqs: + matches = _find_static_matches(identifier, static_reqs, dynamic_reqs) + else: + # Only dynamic dependencies for this NSSN have so far been discovered. + matches = _find_dynamic_matches(identifier, dynamic_reqs, self._dynamic_resolver) + + for incompatible_candidate in incompatible_candidates: + matches.remove(incompatible_candidate) + + return matches + + def is_satisfied_by(self, requirement: Requirement, candidate: Candidate) -> bool: + if isinstance(requirement, StaticRequirement): + # Static requirements are only satisfied by static candidates. As long as the hashes match, the owner + # doesn't matter. + return isinstance(candidate, StaticDependencySolution) and candidate.hash == requirement.dep.hash + + # The root requirement is only satisfied by the root candidate. + if isinstance(requirement, RootRequirementAndCandidate): + return requirement == candidate + + # Dynamic requirements can be satisfied by any kind of candidate so long as the versions match. + return _do_dynamic_reqs_allow_candidate((requirement,), candidate.version) + + def get_dependencies(self, candidate: Candidate) -> Iterable[Requirement]: + for dep in candidate.dependencies.qpy: + if isinstance(dep, DistDynamicQPyDependency): + yield DynamicRequirement(dep) + elif isinstance(dep, DistStaticQPyDependency): + yield StaticRequirement(candidate.nssn, dep) diff --git a/questionpy_server/dependencies/_solver/_reporter.py b/questionpy_server/dependencies/_solver/_reporter.py new file mode 100644 index 00000000..b19b6ce9 --- /dev/null +++ b/questionpy_server/dependencies/_solver/_reporter.py @@ -0,0 +1,125 @@ +import logging +from collections.abc import Mapping +from typing import NamedTuple + +import resolvelib +from resolvelib.resolvers import Criterion +from resolvelib.structs import State + +from questionpy_common import PackageNamespaceAndShortName +from questionpy_common.constants import MAX_QPY_DEPENDENCY_LEVELS + +from ._model import ( + Candidate, + Requirement, + RootRequirementAndCandidate, +) +from .errors import DependencyCycleError, TooDeeplyNestedDependencyError + + +def _tree_path_to_str(path: tuple[PackageNamespaceAndShortName, ...]) -> str: + return " -> ".join(map(str, path)) + + +class QPyResolvelibReporter(resolvelib.BaseReporter[Requirement, Candidate, PackageNamespaceAndShortName]): + def __init__(self, root_nssn: PackageNamespaceAndShortName) -> None: + self._root_nssn = root_nssn + self._logger = logging.getLogger(__name__) + + self._messages: list[str] = [] + + def pinning(self, candidate: Candidate) -> None: + if isinstance(candidate, RootRequirementAndCandidate): + # Not very interesting... + return + + # The solution classes have decent __str__ methods. + self._messages.append(f"Tentatively resolved '{candidate.nssn}' to {candidate}.") + + def rejecting_candidate(self, criterion: Criterion[Requirement, Candidate], candidate: Candidate) -> None: + self._messages.append(f"Rejecting previously pinned resolution of '{candidate.nssn}' to {candidate}.") + + def _format_messages(self) -> str: + return "\n".join(f"\t- {message}" for message in self._messages) + + def ending(self, state: State[Requirement, Candidate, PackageNamespaceAndShortName]) -> None: + # Resolvelib can deal with cycles in some cases, but we (the package initialization in the worker) cannot. + cycle, longest_path = _find_cycle_and_longest_path(state.mapping, self._root_nssn) + + if cycle is None: + self._messages.append("The tree was successfully resolved to a consistent and acyclic graph.") + else: + cycle_as_str = _tree_path_to_str(cycle) + # log_failed_resolution will be called by the exception handler. + self._messages.append( + f"The tree was resolved to a consistent graph, but it contains a cycle: {cycle_as_str}" + ) + + raise DependencyCycleError(cycle) + + if len(longest_path) > MAX_QPY_DEPENDENCY_LEVELS: + raise TooDeeplyNestedDependencyError(longest_path) + + def log_successful_resolution(self) -> None: + # The resolver automatically calls the 'ending' method, but we want to check + # This is called by the resolver automatically when the resolution ends successfully. + + summary = f"Successfully resolved dependency tree of package '{self._root_nssn}'." + + if self._logger.isEnabledFor(logging.DEBUG): + # Detailed output. + self._logger.debug("%s The following steps were taken:\n%s", summary, self._format_messages()) + else: + # Just the summary. + self._logger.info(summary) + + def log_failed_resolution(self, error: Exception) -> None: + # This isn't called by the resolver, we call it when the resolution raises an exception. + + if self._logger.isEnabledFor(logging.INFO): + self._logger.info( + "Failed to resolve dependency tree of package '%s'. The following steps were taken:\n%s", + self._root_nssn, + self._format_messages(), + exc_info=error, + ) + + +class _GraphCycleAndLongestPath(NamedTuple): + first_cycle: tuple[PackageNamespaceAndShortName, ...] | None + longest_path: tuple[PackageNamespaceAndShortName, ...] + + +def _find_cycle_and_longest_path( + mapping: Mapping[PackageNamespaceAndShortName, Candidate], root_nssn: PackageNamespaceAndShortName +) -> _GraphCycleAndLongestPath: + """Performs a depth-first search to find the first cycle and the longest path starting from `root_nssn`.""" + seen = set[PackageNamespaceAndShortName]() + longest_path: tuple[PackageNamespaceAndShortName, ...] = (root_nssn,) + + def recursive_dfs( + path: tuple[PackageNamespaceAndShortName, ...], + ) -> tuple[PackageNamespaceAndShortName, ...] | None: + node = path[-1] + seen.add(node) + candidate = mapping[node] + + for dep in candidate.dependencies.qpy: + new_path = (*path, dep.nssn) + + nonlocal longest_path + if len(new_path) > len(longest_path): + longest_path = new_path + + if len(set(new_path)) != len(new_path): + # There is a cycle up to this dependency. + return new_path + + if dep.nssn not in seen: + cycle = recursive_dfs(new_path) + if cycle is not None: + return cycle + + return None + + return _GraphCycleAndLongestPath(recursive_dfs((root_nssn,)), longest_path) diff --git a/questionpy_server/dependencies/_solver/errors.py b/questionpy_server/dependencies/_solver/errors.py new file mode 100644 index 00000000..6c10155b --- /dev/null +++ b/questionpy_server/dependencies/_solver/errors.py @@ -0,0 +1,56 @@ +from collections.abc import Iterable + +from resolvelib.structs import RequirementInformation + +from questionpy_common import PackageNamespaceAndShortName +from questionpy_common.constants import MAX_QPY_DEPENDENCY_LEVELS +from questionpy_common.manifest import AbstractDynamicQPyDependency, DistStaticQPyDependency + +from ._model import Candidate, Requirement, RootRequirementAndCandidate + + +def _dep_version_to_str(dep: AbstractDynamicQPyDependency | DistStaticQPyDependency) -> str: + if isinstance(dep, AbstractDynamicQPyDependency): + string = str(dep.version) if dep.version else "any version" + string += " (including prereleases)" if dep.include_prereleases else " (excluding prereleases)" + return string + + return f"statically packaged version {dep.version} ({dep.hash})" + + +def _tree_path_to_str(path: tuple[PackageNamespaceAndShortName, ...]) -> str: + return " -> ".join(map(str, path)) + + +class QPyDependencyError(Exception): + pass + + +class DependencyConflictError(QPyDependencyError): + def __init__( + self, nssn: PackageNamespaceAndShortName, causes: Iterable[RequirementInformation[Requirement, Candidate]] + ) -> None: + msg = f"No version of '{nssn}' could be found that satisfies all of the following dependencies:" + for req, parent in causes: + if isinstance(req, RootRequirementAndCandidate): + msg += f"\n\t- root package ({req.version})" + else: + parent_str = f"'{parent.nssn}:{parent.version}'" if parent else "unexpected top-level requirement" + msg += f"\n\t- via {parent_str}: {_dep_version_to_str(req.dep)}" + + super().__init__(msg) + self.nssn = nssn + + +class DependencyCycleError(QPyDependencyError): + def __init__(self, path: tuple[PackageNamespaceAndShortName, ...]) -> None: + msg = f"The dependency tree contains a cycle: {_tree_path_to_str(path)}" + super().__init__(msg) + + self.path = path + + +class TooDeeplyNestedDependencyError(QPyDependencyError): + def __init__(self, path: tuple[PackageNamespaceAndShortName, ...]) -> None: + msg = f"Dependency graph is deeper than {MAX_QPY_DEPENDENCY_LEVELS} levels at {_tree_path_to_str(path)}." + super().__init__(msg) diff --git a/questionpy_server/package.py b/questionpy_server/package.py index f767985d..6718d629 100644 --- a/questionpy_server/package.py +++ b/questionpy_server/package.py @@ -6,13 +6,13 @@ from pathlib import Path from typing import TYPE_CHECKING +from questionpy_common.package_location import ZipPackageLocation from questionpy_server.collector.abc import BaseCollector from questionpy_server.collector.lms_collector import LMSCollector from questionpy_server.collector.local_collector import LocalCollector from questionpy_server.collector.repo_collector import RepoCollector from questionpy_server.models import PackageVersionInfo -from questionpy_server.utils.manifest import ComparableManifest -from questionpy_server.worker.runtime.package_location import ZipPackageLocation +from questionpy_server.utils.manifest import Manifest if TYPE_CHECKING: from questionpy_server.collector.abc import BaseCollector @@ -103,7 +103,7 @@ def is_local(self) -> bool: class Package: hash: str - manifest: ComparableManifest + manifest: Manifest sources: PackageSources @@ -113,7 +113,7 @@ class Package: def __init__( self, package_hash: str, - manifest: ComparableManifest, + manifest: Manifest, source: "BaseCollector | None" = None, path: Path | None = None, ): diff --git a/questionpy_server/repository/models.py b/questionpy_server/repository/models.py index 833c22b7..7f7db027 100644 --- a/questionpy_server/repository/models.py +++ b/questionpy_server/repository/models.py @@ -8,7 +8,7 @@ from pydantic import BaseModel, PositiveInt -from questionpy_server.utils.manifest import ComparableManifest, SemVer +from questionpy_server.utils.manifest import Manifest, ParsableSemverVersion class RepoMeta(BaseModel): @@ -28,7 +28,7 @@ class RepoMeta(BaseModel): class RepoPackageVersion(BaseModel): """Represents a specific version of a package in the repository.""" - version: SemVer + version: ParsableSemverVersion """Version of the package.""" api_version: str """Compatible API version of the package.""" @@ -57,7 +57,7 @@ def __hash__(self) -> int: class RepoPackageVersions(BaseModel): """Represents a package with all its versions in the repository.""" - manifest: ComparableManifest + manifest: Manifest """Manifest of the most recent version of the package.""" versions: list[RepoPackageVersion] """List of all versions of the package.""" @@ -74,7 +74,7 @@ class RepoPackageIndex(BaseModel): class RepoPackage: """Represents a package in the repository.""" - manifest: ComparableManifest + manifest: Manifest """Manifest of the package.""" path: str @@ -85,7 +85,7 @@ class RepoPackage: """SHA256 hash of the package.""" @classmethod - def combine(cls, manifest: ComparableManifest, repo_package_version: RepoPackageVersion) -> "RepoPackage": + def combine(cls, manifest: Manifest, repo_package_version: RepoPackageVersion) -> "RepoPackage": """Combines the manifest of a package with a specific version of that package. Args: @@ -94,7 +94,7 @@ def combine(cls, manifest: ComparableManifest, repo_package_version: RepoPackage """ # Replace package version and api version with actual versions. modified_manifest = manifest.model_copy(deep=True) - modified_manifest.version = repo_package_version.version + modified_manifest.version = str(repo_package_version.version) modified_manifest.api_version = repo_package_version.api_version return cls( diff --git a/questionpy_server/utils/manifest.py b/questionpy_server/utils/manifest.py index 0be0cf7d..96e9007c 100644 --- a/questionpy_server/utils/manifest.py +++ b/questionpy_server/utils/manifest.py @@ -3,38 +3,61 @@ # (c) Technische Universität Berlin, innoCampus import asyncio +from asyncio import to_thread +from contextlib import ExitStack from pathlib import Path -from typing import Annotated +from typing import IO, Annotated, Any from zipfile import BadZipFile, ZipFile from pydantic import PlainSerializer, PlainValidator, ValidationError -from semver import VersionInfo as _Version +from semver import Version from questionpy_common.constants import DIST_DIR, MANIFEST_FILENAME, MAX_MANIFEST_SIZE from questionpy_common.error import QPyBaseError from questionpy_common.manifest import Manifest +from questionpy_common.package_location import ( + FunctionPackageLocation, + PackageLocation, + ZipPackageLocation, +) -type SemVer = Annotated[_Version, PlainValidator(_Version.parse), PlainSerializer(_Version.__str__)] + +class ManifestError(QPyBaseError): + pass -class ComparableManifest(Manifest): - version: SemVer # type: ignore[assignment] +def _read_manifest_from_file_sync(manifest_file: IO[bytes]) -> Manifest: + try: + buffer = manifest_file.read(MAX_MANIFEST_SIZE + 1) + if len(buffer) > MAX_MANIFEST_SIZE: + msg = f"Manifest is too large. Maximum size is {MAX_MANIFEST_SIZE.human_readable()}." + raise ManifestError(msg) -class ManifestError(QPyBaseError): - pass + return Manifest.model_validate_json(buffer) + except ValidationError as e: + msg = f"Manifest is invalid: {e}" + raise ManifestError(msg) from e + + +def _read_manifest_from_path_sync(manifest_path: Path) -> Manifest: + try: + with manifest_path.open("rb") as file: + return _read_manifest_from_file_sync(file) + except FileNotFoundError as e: + msg = "Manifest is missing." + raise ManifestError(msg) from e -def _read_manifest_sync(package_path: Path) -> ComparableManifest: +def _read_manifest_from_zip_sync(package: Path | ZipFile) -> Manifest: try: - with ZipFile(package_path) as zip_file, zip_file.open(f"{DIST_DIR}/{MANIFEST_FILENAME}") as manifest_file: - buffer = manifest_file.read(MAX_MANIFEST_SIZE + 1) + with ExitStack() as stack: + if isinstance(package, Path): + package = stack.enter_context(ZipFile(package)) - if len(buffer) > MAX_MANIFEST_SIZE: - msg = f"Manifest is too large. Maximal size is {MAX_MANIFEST_SIZE.human_readable()}." - raise ManifestError(msg) + manifest_file = stack.enter_context(package.open(f"{DIST_DIR}/{MANIFEST_FILENAME}")) - return ComparableManifest.model_validate_json(buffer) + return _read_manifest_from_file_sync(manifest_file) except BadZipFile as e: msg = f"Could not read manifest from package: {e}" raise ManifestError(msg) from e @@ -42,15 +65,35 @@ def _read_manifest_sync(package_path: Path) -> ComparableManifest: # ZipFile.open raises a KeyError if the file does not exist. msg = "Manifest is missing." raise ManifestError(msg) from e - except ValidationError as e: - msg = f"Manifest is invalid: {e}" - raise ManifestError(msg) from e -async def read_manifest(package_path: Path) -> ComparableManifest: +async def read_manifest_from_zip(package: Path | ZipFile) -> Manifest: """Reads the manifest from a zipped package. Raises: - ManifestError: if the manifest could not be read, is too large, or is invalid + ManifestError: if the manifest could not be read, it is too large or is invalid """ - return await asyncio.to_thread(_read_manifest_sync, package_path) + return await asyncio.to_thread(_read_manifest_from_zip_sync, package) + + +async def read_manifest_from_location(location: PackageLocation) -> Manifest: + if isinstance(location, ZipPackageLocation): + return await read_manifest_from_zip(location.path) + if isinstance(location, FunctionPackageLocation): + return Manifest(**location.manifest.model_dump()) + + manifest_path = location.path / MANIFEST_FILENAME + return await to_thread(_read_manifest_from_path_sync, manifest_path) + + +def _maybe_parse_version(value: Any) -> Any: + if isinstance(value, Version): + return value + if isinstance(value, str): + return Version.parse(value) + return value + + +type ParsableSemverVersion = Annotated[ + Version, PlainValidator(_maybe_parse_version, json_schema_input_type=str), PlainSerializer(Version.__str__) +] diff --git a/questionpy_server/web/_routes/_files.py b/questionpy_server/web/_routes/_files.py index 62d4c10c..38080917 100644 --- a/questionpy_server/web/_routes/_files.py +++ b/questionpy_server/web/_routes/_files.py @@ -6,10 +6,8 @@ from aiohttp.web_exceptions import HTTPNotImplemented from questionpy_server.package import Package -from questionpy_server.web import CURRENT_USER_KEY from questionpy_server.web._decorators import ensure_package -from questionpy_server.web.app import QPyServer -from questionpy_server.worker.selector import SelectorQuery +from questionpy_server.web._worker_context import worker_context file_routes = web.RouteTableDef() @@ -17,7 +15,6 @@ @file_routes.post(r"/packages/{package_hash}/file/{namespace}/{short_name}/{path:static/.*}") @ensure_package async def serve_static_file(request: web.Request, package: Package) -> web.Response: - qpy_server = request.app[QPyServer.APP_KEY] namespace = request.match_info["namespace"] short_name = request.match_info["short_name"] path = request.match_info["path"] @@ -26,17 +23,9 @@ async def serve_static_file(request: web.Request, package: Package) -> web.Respo # TODO: Support static files in non-main packages by using namespace and short_name. raise HTTPNotImplemented(text="Static file retrieval from non-main packages is not supported yet.") - current_user = request.get(CURRENT_USER_KEY) - selector_query = SelectorQuery(package, current_user, "files") - permissions = qpy_server.package_permissions.get(selector_query) - environment_variables = qpy_server.environment_variables.get(selector_query) - location = await package.get_zip_package_location() - - async with qpy_server.worker_pool.get_worker( - location, current_user, "files", permissions, environment_variables - ) as worker: + async with worker_context(request, package, context="files") as context: try: - file = await worker.get_static_file(path) + file = await context.worker.get_static_file(path) except FileNotFoundError as e: raise web.HTTPNotFound(text="File not found.") from e diff --git a/questionpy_server/web/_worker_context.py b/questionpy_server/web/_worker_context.py index 1b31c5c7..b2a94576 100644 --- a/questionpy_server/web/_worker_context.py +++ b/questionpy_server/web/_worker_context.py @@ -1,6 +1,6 @@ from collections.abc import AsyncIterator -from contextlib import asynccontextmanager -from typing import NamedTuple +from contextlib import AbstractAsyncContextManager, asynccontextmanager +from typing import Any, NamedTuple, overload from aiohttp import web @@ -32,13 +32,30 @@ class WorkerContext(NamedTuple): permissions: PackagePermissions +@overload +def worker_context( + request: web.Request, package: Package, *, context: str +) -> AbstractAsyncContextManager[WorkerContext]: ... + + +@overload +def worker_context( + request: web.Request, package: Package, data: RequestBaseData +) -> AbstractAsyncContextManager[WorkerContext]: ... + + @asynccontextmanager -async def worker_context(request: web.Request, package: Package, data: RequestBaseData) -> AsyncIterator[WorkerContext]: +async def worker_context( + request: web.Request, package: Package, data: Any = None, *, context: Any = None +) -> AsyncIterator[WorkerContext]: """Returns the worker context for the given request.""" + if data: + context = data.context + qpyserver = request.app[QPyServer.APP_KEY] current_user = request.get(CURRENT_USER_KEY) - selector_query = SelectorQuery(package, current_user, data.context) + selector_query = SelectorQuery(package, current_user, context) permissions = qpyserver.package_permissions.get(selector_query) environment_variables = qpyserver.environment_variables.get(selector_query) @@ -49,7 +66,7 @@ async def worker_context(request: web.Request, package: Package, data: RequestBa lms_provided_attributes = data.lms_provided_attributes async with qpyserver.worker_pool.get_worker( - location, current_user, data.context, permissions, environment_variables + location, current_user, context, permissions, environment_variables ) as worker: yield WorkerContext( worker, diff --git a/questionpy_server/web/app.py b/questionpy_server/web/app.py index 915d5b42..8d215055 100644 --- a/questionpy_server/web/app.py +++ b/questionpy_server/web/app.py @@ -12,6 +12,7 @@ from questionpy_server import __version__ from questionpy_server.cache import LRUCache, LRUCacheSupervisor from questionpy_server.collector import PackageCollection +from questionpy_server.dependencies import PackageCollectionDependencyResolver from questionpy_server.settings import Settings from questionpy_server.web.middlewares import middlewares from questionpy_server.worker.pool import WorkerPool @@ -33,9 +34,6 @@ def __init__(self, settings: Settings): self.web_app.add_routes(routes) self.web_app[self.APP_KEY] = self - self.worker_pool = WorkerPool( - settings.worker_pool.max_cpus, settings.worker_pool.max_memory, worker_type=settings.worker_pool.type - ) self.package_permissions = PackagePermissionsHandler(settings.permissions) self.environment_variables = EnvironmentVariablesHandler(settings.environment_variables) @@ -48,7 +46,14 @@ def __init__(self, settings: Settings): settings.collector.repositories, self.repo_index_cache, self.package_cache, - self.worker_pool, + ) + + worker_dependency_resolver = PackageCollectionDependencyResolver(self.package_collection) + self.worker_pool = WorkerPool( + settings.worker_pool.max_cpus, + settings.worker_pool.max_memory, + worker_type=settings.worker_pool.type, + dependency_resolver=worker_dependency_resolver, ) self.web_app.cleanup_ctx.append(self._worker_pool_ctx) diff --git a/questionpy_server/worker/__init__.py b/questionpy_server/worker/__init__.py index ece7f186..f5561e07 100644 --- a/questionpy_server/worker/__init__.py +++ b/questionpy_server/worker/__init__.py @@ -2,6 +2,7 @@ # The QuestionPy Server is free software released under terms of the MIT license. See LICENSE.md. # (c) Technische Universität Berlin, innoCampus from abc import ABC, abstractmethod +from collections.abc import Mapping from dataclasses import dataclass from enum import Enum from pathlib import Path @@ -9,15 +10,17 @@ from pydantic import BaseModel +from questionpy_common import PackageNamespaceAndShortName from questionpy_common.api.attempt import AttemptModel, AttemptScoredModel, AttemptStartedModel from questionpy_common.api.question import LmsPermissions +from questionpy_common.dependencies import SolutionAndLocation from questionpy_common.elements import OptionsFormDefinition from questionpy_common.environment import PackagePermissions, RequestInfo from questionpy_common.manifest import PackageFile +from questionpy_common.package_location import PackageLocation from questionpy_server.models import LoadedPackage, QuestionCreated -from questionpy_server.utils.manifest import ComparableManifest +from questionpy_server.utils.manifest import Manifest from questionpy_server.worker.runtime.messages import MessageToServer, MessageToWorker -from questionpy_server.worker.runtime.package_location import PackageLocation class WorkerResources(BaseModel): @@ -55,15 +58,22 @@ class PackageFileData: class WorkerArgs(TypedDict): name: str """A unique name given to the worker by its pool.""" + package: PackageLocation """The main package that the worker should load when [start][questionpy_server.worker.Worker.start] is called.""" + worker_home: Path """An existing directory owned by the worker, with the same lifetime as the worker.""" + permissions: PackagePermissions """The package permissions.""" + environment_variables: dict[str, str] """Environment variables to be set in the worker.""" + dependencies: Mapping[PackageNamespaceAndShortName, SolutionAndLocation] + """All resolved dependencies in the root package's tree. Does not include the root package itself.""" + class Worker(ABC): """Interface for worker implementations.""" @@ -75,6 +85,7 @@ def __init__(self, **kwargs: Unpack[WorkerArgs]) -> None: self.worker_home = kwargs["worker_home"] self.permissions = kwargs["permissions"] self.environment_variables = kwargs["environment_variables"] + self._dependencies = kwargs["dependencies"] self.state = WorkerState.NOT_RUNNING self.loaded_packages: list[LoadedPackage] = [] @@ -114,7 +125,7 @@ async def get_resource_usage(self) -> WorkerResources | None: """Get the worker's current resource usage. If unknown or unsupported, return None.""" @abstractmethod - async def get_manifest(self) -> ComparableManifest: + async def get_manifest(self) -> Manifest: """Get manifest of the main package in the worker.""" @abstractmethod diff --git a/questionpy_server/worker/impl/_base.py b/questionpy_server/worker/impl/_base.py index 38660703..448995e9 100644 --- a/questionpy_server/worker/impl/_base.py +++ b/questionpy_server/worker/impl/_base.py @@ -20,8 +20,13 @@ from questionpy_common.elements import OptionsFormDefinition from questionpy_common.environment import RequestInfo from questionpy_common.manifest import Manifest, PackageFile +from questionpy_common.package_location import ( + DirPackageLocation, + FunctionPackageLocation, + PackageLocation, + ZipPackageLocation, +) from questionpy_server.models import LoadedPackage, QuestionCreated -from questionpy_server.utils.manifest import ComparableManifest from questionpy_server.worker import PackageFileData, Worker, WorkerArgs, WorkerState from questionpy_server.worker.exception import ( StaticFileSizeMismatchError, @@ -47,12 +52,6 @@ ViewAttempt, WorkerError, ) -from questionpy_server.worker.runtime.package_location import ( - DirPackageLocation, - FunctionPackageLocation, - PackageLocation, - ZipPackageLocation, -) if TYPE_CHECKING: from questionpy_server.worker.connection import ServerToWorkerConnection @@ -113,7 +112,7 @@ async def _initialize(self) -> None: async def _load_package(self, package_location: PackageLocation, *, main: bool) -> None: loaded = await self.send_and_wait_for_response( - LoadQPyPackage(location=package_location, main=main), + LoadQPyPackage(location=package_location, main=main, dependencies=self._dependencies), LoadQPyPackage.Response, self.permissions.bootstrap_timeout, ) @@ -226,9 +225,9 @@ async def stop(self, timeout: float) -> None: except TimeoutError: log.info("Worker was killed because it did not stop gracefully") - async def get_manifest(self) -> ComparableManifest: + async def get_manifest(self) -> Manifest: ret = await self.send_and_wait_for_response(GetQPyPackageManifest(), GetQPyPackageManifest.Response) - return ComparableManifest(**ret.manifest.model_dump()) + return ret.manifest async def get_options_form( self, request_info: RequestInfo, question_state: str | None @@ -374,7 +373,7 @@ class LimitTimeUsageMixin(Worker, ABC): the cpu limit in real time. """ - _real_time_limit_factor = 3 + _real_time_limit_factor = 1000 def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) diff --git a/questionpy_server/worker/pool.py b/questionpy_server/worker/pool.py index 37ce09b8..0d066dbe 100644 --- a/questionpy_server/worker/pool.py +++ b/questionpy_server/worker/pool.py @@ -14,15 +14,20 @@ from pydantic import ByteSize +from questionpy_common.dependencies import DependencySolution, SolutionAndLocation, StaticDependencySolution from questionpy_common.environment import PackagePermissions from questionpy_common.error import QPyBaseError -from questionpy_server.worker.impl.subprocess import SubprocessWorker -from questionpy_server.worker.runtime.package_location import ( +from questionpy_common.manifest import Manifest +from questionpy_common.package_location import ( DirPackageLocation, FunctionPackageLocation, PackageLocation, ZipPackageLocation, ) +from questionpy_server.dependencies import DynamicDependencyResolver, resolve_dependency_tree +from questionpy_server.package import Package +from questionpy_server.utils.manifest import read_manifest_from_location +from questionpy_server.worker.impl.subprocess import SubprocessWorker from . import Worker, WorkerState @@ -42,19 +47,39 @@ class _IdleWorkersIdentifier(NamedTuple): context: str +async def _get_location_if_dynamic( + resolver: DynamicDependencyResolver, solution: DependencySolution +) -> SolutionAndLocation: + if isinstance(solution, StaticDependencySolution): + return solution, None + + package = await resolver.get_package_location(solution.hash) + + return solution, package + + class WorkerPool: - def __init__(self, max_workers: int, max_memory: int, worker_type: type[Worker] = SubprocessWorker): + def __init__( + self, + max_workers: int, + max_memory: int, + *, + worker_type: type[Worker] = SubprocessWorker, + dependency_resolver: DynamicDependencyResolver, + ) -> None: """Initialize the worker pool. Args: max_workers (int): maximum number of workers being executed in parallel max_memory (int): maximum memory (in bytes) that all workers in the pool are allowed to consume worker_type (type[Worker]): worker implementation + dependency_resolver: resolver for dynamic dependencies of the package """ self.max_workers = max_workers self.max_memory = max_memory self._worker_type = worker_type + self._dependency_resolver = dependency_resolver self._lock: Lock = Lock() self._semaphore: Semaphore = Semaphore(self.max_workers) @@ -101,7 +126,7 @@ def _memory_available(self, required_memory: int) -> bool: @asynccontextmanager async def get_worker( self, - package: PackageLocation, + package: Package | PackageLocation, user: str | None, context: str, permissions: PackagePermissions, @@ -112,17 +137,25 @@ async def get_worker( A context manager is used to ensure that a worker is always given back to the pool. Args: - package: path to QuestionPy package + package: The main package to be run, either as a [Package][questionpy_server.package.Package] instance, or a + specific [PackageLocation][questionpy_server.worker.runtime.package_location.PackageLocation]. user: the user requesting the worker context: context within the lms permissions: package permissions environment_variables: environment variables to be set in the worker Returns: - A worker + A new or previously idle worker running the given package. """ self._workers_requested += 1 + if isinstance(package, Package): + manifest = package.manifest + package_location: PackageLocation = await package.get_zip_package_location() + else: + manifest = await read_manifest_from_location(package) + package_location = package + # Limit the number of running workers. async with self._semaphore: worker = None @@ -139,7 +172,12 @@ async def get_worker( async with self._lock, self._condition: await self._condition.wait_for(lambda: self._memory_available(permissions.memory)) worker = await self._create_or_reuse_worker( - package, user, context, permissions, environment_variables + package_location=package_location, + manifest=manifest, + user=user, + context=context, + permissions=permissions, + environment_variables=environment_variables, ) self._workers_in_use += 1 @@ -147,7 +185,7 @@ async def get_worker( finally: if worker: async with self._condition: - await self._handle_idle_worker(package, user, context, worker) + await self._handle_idle_worker(package_location, user, context, worker) self._condition.notify() self._workers_in_use -= 1 @@ -209,16 +247,17 @@ def _generate_worker_name(self, package: PackageLocation) -> str: async def _create_or_reuse_worker( self, - package: PackageLocation, + *, + package_location: PackageLocation, + manifest: Manifest, user: str | None, context: str, permissions: PackagePermissions, environment_variables: dict[str, str], ) -> Worker: """If possible, get an idle worker or create a new one.""" - # Since the `PackagePermissions` only dependent on the `user` and `context` the worker - # permissions are the same. - identifier = _IdleWorkersIdentifier(package, user, context) + # Since the `PackagePermissions` only depend on the `user` and `context`, the worker permissions are the same. + identifier = _IdleWorkersIdentifier(package_location, user, context) if identifier in self._idle_workers: # There is an idle worker with this package loaded - reuse the most recent one. worker = self._idle_workers[identifier].popleft() @@ -229,18 +268,28 @@ async def _create_or_reuse_worker( self._memory_idle -= worker.permissions.memory else: + if manifest.dependencies.qpy: + solutions = resolve_dependency_tree(manifest, manifest.dependencies.qpy, self._dependency_resolver) + solutions_and_locations = { + nssn: await _get_location_if_dynamic(self._dependency_resolver, solution) + for nssn, solution in solutions.items() + } + else: + solutions_and_locations = {} + # We need to create a new worker - free as much memory as needed to start the worker. await self._free_memory(permissions.memory) - name = self._generate_worker_name(package) + name = self._generate_worker_name(package_location) worker_home = self._working_dir / f"worker-{name}" await asyncio.to_thread(worker_home.mkdir) worker = self._worker_type( name=name, - package=package, + package=package_location, permissions=permissions, worker_home=worker_home, + dependencies=solutions_and_locations, environment_variables=environment_variables, ) await worker.start() diff --git a/questionpy_server/worker/runtime/manager.py b/questionpy_server/worker/runtime/manager.py index 58b99b21..8e89f1bc 100644 --- a/questionpy_server/worker/runtime/manager.py +++ b/questionpy_server/worker/runtime/manager.py @@ -7,22 +7,24 @@ from contextlib import contextmanager from dataclasses import dataclass from graphlib import TopologicalSorter +from itertools import chain from pathlib import Path from types import MappingProxyType from typing import TYPE_CHECKING, NoReturn, TypeVar, cast -from questionpy_common.constants import MAX_QPY_DEPENDENCY_LEVELS +from questionpy_common import PackageNamespaceAndShortName +from questionpy_common.dependencies import SolutionAndLocation, StaticDependencySolution from questionpy_common.environment import ( Environment, OnRequestCallback, Package, - PackageNamespaceAndShortName, PackagePermissions, PackageState, RequestInfo, set_qpy_environment, ) from questionpy_common.manifest import PackageType +from questionpy_common.package_location import PackageLocation from questionpy_server.worker.runtime.connection import WorkerToServerConnection from questionpy_server.worker.runtime.messages import ( CreateQuestionFromOptions, @@ -40,7 +42,6 @@ WorkerError, ) from questionpy_server.worker.runtime.package import ImportablePackage, NoInitFunctionError, open_qpy_package -from questionpy_server.worker.runtime.package_location import PackageLocation if TYPE_CHECKING: from questionpy_common.api.qtype import QuestionTypeInterface @@ -88,12 +89,14 @@ def register_on_request_callback(self, callback: OnRequestCallback) -> None: type OnMessageCallback[M: MessageToWorker] = Callable[[M], MessageToServer] -def _linearize_packages( - packages: Mapping[PackageNamespaceAndShortName, ImportablePackage], +def _linearize_dependencies( + solutions: Mapping[PackageNamespaceAndShortName, SolutionAndLocation], ) -> Sequence[PackageNamespaceAndShortName]: sorter = TopologicalSorter[PackageNamespaceAndShortName]() - for nssn, package in packages.items(): - sorter.add(nssn, *package.dependencies.keys()) + + for nssn, (solution, _) in solutions.items(): + dep_nssns = [PackageNamespaceAndShortName(dep.namespace, dep.short_name) for dep in solution.dependencies.qpy] + sorter.add(nssn, *dep_nssns) return tuple(sorter.static_order()) @@ -160,68 +163,65 @@ def _open_package(location: PackageLocation, worker_home: Path) -> ImportablePac # This is a separate method to allow it to be mocked separately. return open_qpy_package(location, worker_home) - def _open_packages_recursively( - self, - msg: LoadQPyPackage, - package_location: PackageLocation, - stack: tuple[PackageNamespaceAndShortName, ...] = (), - ) -> tuple[PackageNamespaceAndShortName, ImportablePackage]: - if not self._env or not self._worker_home: - self._raise_not_initialized(msg) - - package = self._open_package(package_location, self._worker_home) - nssn = PackageNamespaceAndShortName(package.manifest.namespace, package.manifest.short_name) - - if nssn in stack and self._packages[nssn].manifest.version == package.manifest.version: - raise CircularDependencyError(nssn, stack) - - if nssn in self._packages: - # For now, we don't support two packages using the same static dependency, even if they would use the same - # version. Supporting the latter case would require us to either trust or check that both dependency's - # content is identical. - err_msg = f"Package '{nssn}' is already loaded. Dependency stack: {stack}" - raise DependencyError(err_msg, stack) + def _init_package(self, nssn: PackageNamespaceAndShortName, env: Environment) -> None: + package = self._packages[nssn] - self._packages[nssn] = package + # Make the package's dependencies accessible to the package. + for dep in package.manifest.dependencies.qpy: + dep_nssn = PackageNamespaceAndShortName(dep.namespace, dep.short_name) + dep_package = self._packages.get(dep_nssn) + if not dep_package: + err_msg = f"Unfulfilled dependency of '{nssn}': '{dep_nssn}'" + raise RuntimeError(err_msg) - new_stack = (*stack, nssn) + package.dependencies[dep_nssn] = dep_package - if len(stack) >= MAX_QPY_DEPENDENCY_LEVELS and package.manifest.dependencies.qpy: - raise TooDeeplyNestedDependencyError(new_stack) + if package.state < PackageState.LOADED: + package.load() - for dep_location in package.resolve_static_dependencies(): - dep_nssn, dep_package = self._open_packages_recursively(msg, dep_location, new_stack) - package.dependencies[dep_nssn] = dep_package + if package.state < PackageState.INITIALIZED: + is_question_like = package.manifest.type in {PackageType.QUESTION, PackageType.QUESTIONTYPE} + try: + package.init(env) + except NoInitFunctionError: + if is_question_like: + # Questions and question types MUST have init functions. (Others MAY.) + raise - return nssn, package + if package is env.main_package and is_question_like: + self._question_type = cast("QuestionTypeInterface", package.interface) def on_msg_load_qpy_package(self, msg: LoadQPyPackage) -> MessageToServer: if not self._env or not self._worker_home: self._raise_not_initialized(msg) - root_nssn, root_package = self._open_packages_recursively(msg, msg.location, ()) + root_package = self._open_package(msg.location, self._worker_home) + root_nssn = root_package.manifest.nssn + self._packages[root_nssn] = root_package + + linearized = _linearize_dependencies(msg.dependencies) + + for nssn in reversed(linearized): + solution, package_location = msg.dependencies[nssn] + if isinstance(solution, StaticDependencySolution): + owner = self._packages.get(solution.owner) + if not owner: + # Since we open packages in reverse topological order, this shouldn't happen. + # (Unless the tree passed to us by the server contains errors.) + err_msg = f"Cannot open static dependency '{nssn}' before owner '{solution.owner}'." + raise RuntimeError(err_msg) + + package_location = owner.resolve_static_dependency(nssn) + + # MyPy doesn't narrow the type properly. + self._packages[nssn] = self._open_package(cast("PackageLocation", package_location), self._worker_home) if msg.main: self._env = dataclasses.replace(self._env, _main_package=root_package) set_qpy_environment(self._env) - linearized = _linearize_packages(self._packages) - for nssn in linearized: - package = self._packages[nssn] - if package.state < PackageState.LOADED: - package.load() - - if package.state < PackageState.INITIALIZED: - is_question_like = package.manifest.type in {PackageType.QUESTION, PackageType.QUESTIONTYPE} - try: - package.init(self._env) - except NoInitFunctionError: - if is_question_like: - # Questions and question types MUST have init functions. (Others MAY.) - raise - - if package is root_package and msg.main and is_question_like: - self._question_type = cast("QuestionTypeInterface", package.interface) + for nssn in chain(linearized, (root_nssn,)): + self._init_package(nssn, self._env) return LoadQPyPackage.Response(root_nssn=root_nssn, loaded_packages=linearized) @@ -329,19 +329,3 @@ class WorkerNotInitializedError(Exception): class MainPackageNotLoadedError(Exception): pass - - -class DependencyError(Exception): - def __init__(self, message: str, stack: tuple[PackageNamespaceAndShortName, ...]) -> None: - super().__init__(message) - self.stack = stack - - -class CircularDependencyError(DependencyError): - def __init__(self, nssn: PackageNamespaceAndShortName, stack: tuple[PackageNamespaceAndShortName, ...]): - super().__init__(f"'{nssn}'. Dependency stack: {stack}", stack) - - -class TooDeeplyNestedDependencyError(DependencyError): - def __init__(self, stack: tuple[PackageNamespaceAndShortName, ...]) -> None: - super().__init__(f"Dependency graph is deeper than '{MAX_QPY_DEPENDENCY_LEVELS}' levels at '{stack}'.", stack) diff --git a/questionpy_server/worker/runtime/messages.py b/questionpy_server/worker/runtime/messages.py index e4447b8b..bfeb6fa7 100644 --- a/questionpy_server/worker/runtime/messages.py +++ b/questionpy_server/worker/runtime/messages.py @@ -10,14 +10,16 @@ from pydantic import BaseModel, JsonValue +from questionpy_common import PackageNamespaceAndShortName from questionpy_common.api.attempt import AttemptModel, AttemptScoredModel, AttemptStartedModel from questionpy_common.api.qtype import InvalidQuestionStateError, OptionsFormValidationError from questionpy_common.api.question import QuestionModel +from questionpy_common.dependencies import SolutionAndLocation from questionpy_common.elements import OptionsFormDefinition -from questionpy_common.environment import PackageNamespaceAndShortName, PackagePermissions, RequestInfo +from questionpy_common.environment import PackagePermissions, RequestInfo from questionpy_common.error import QPyBaseError from questionpy_common.manifest import Manifest -from questionpy_server.worker.runtime.package_location import PackageLocation +from questionpy_common.package_location import PackageLocation messages_header_struct: Struct = Struct("=LL") """4 bytes unsigned long int message id and 4 bytes unsigned long int payload length""" @@ -117,6 +119,9 @@ class LoadQPyPackage(MessageToWorker): main: bool """Set this package as the main package and execute its entry point.""" + dependencies: dict[PackageNamespaceAndShortName, SolutionAndLocation] + """All resolved dependencies in the root package's tree. Does not include the root package itself.""" + class Response(MessageToServer): """Success message in return to LoadQPyPackage.""" diff --git a/questionpy_server/worker/runtime/package.py b/questionpy_server/worker/runtime/package.py index 9e895f3f..f0fd66b7 100644 --- a/questionpy_server/worker/runtime/package.py +++ b/questionpy_server/worker/runtime/package.py @@ -11,18 +11,18 @@ from types import ModuleType from zipfile import ZipFile +from questionpy_common import PackageNamespaceAndShortName from questionpy_common.api.package import QPyPackageInterface from questionpy_common.constants import DIST_DIR, MANIFEST_FILENAME from questionpy_common.environment import ( Environment, Package, - PackageNamespaceAndShortName, PackageNotInitializedError, PackageNotLoadedError, PackageState, ) from questionpy_common.manifest import DistStaticQPyDependency, Manifest -from questionpy_server.worker.runtime.package_location import ( +from questionpy_common.package_location import ( DirPackageLocation, FunctionPackageLocation, PackageLocation, @@ -81,7 +81,7 @@ def init(self, env: Environment) -> None: """ @abstractmethod - def resolve_static_dependencies(self) -> list[PackageLocation]: + def resolve_static_dependency(self, nssn: PackageNamespaceAndShortName) -> PackageLocation: pass @@ -100,12 +100,33 @@ def __repr__(self) -> str: __str__ = __repr__ - def resolve_static_dependencies(self) -> list[PackageLocation]: - return [ - DirPackageLocation(self.path / "dependencies" / "qpy" / dep.dir_name / DIST_DIR) - for dep in self.manifest.dependencies.qpy - if isinstance(dep, DistStaticQPyDependency) - ] + def resolve_static_dependency(self, nssn: PackageNamespaceAndShortName) -> PackageLocation: + dep = next( + ( + dep + for dep in self.manifest.dependencies.qpy + if isinstance(dep, DistStaticQPyDependency) + and dep.namespace == nssn.namespace + and dep.short_name == nssn.short_name + for dep in self.manifest.dependencies.qpy + ), + None, + ) + if not dep: + msg = f"Package '{self.manifest.nssn}' does not provide static dependency '{nssn}'." + raise RuntimeError(msg) + + dep_dist_path = ( + self.path / "dependencies" / "qpy" / f"{dep.namespace}-{dep.short_name}-{dep.version}" / DIST_DIR + ) + if not dep_dist_path.exists(): + msg = ( + f"Package '{self.manifest.nssn}' lists static dependency '{nssn}', but '{dep_dist_path}' is not " + f"present." + ) + raise RuntimeError(msg) + + return DirPackageLocation(dep_dist_path) def load(self) -> None: for new_path in ( @@ -166,8 +187,8 @@ def __repr__(self) -> str: __str__ = __repr__ - def resolve_static_dependencies(self) -> list[PackageLocation]: - return [] + def resolve_static_dependency(self, nssn: PackageNamespaceAndShortName) -> PackageLocation: + raise NotImplementedError def _package_dir(worker_home: Path, manifest: Manifest) -> Path: diff --git a/questionpy_server/worker/selector/__init__.py b/questionpy_server/worker/selector/__init__.py index b634af25..467cc81c 100644 --- a/questionpy_server/worker/selector/__init__.py +++ b/questionpy_server/worker/selector/__init__.py @@ -3,6 +3,8 @@ # (c) Technische Universität Berlin, innoCampus from typing import NamedTuple +import semver + from questionpy_server.package import Package from questionpy_server.settings import PackageSelector, Selectable @@ -18,12 +20,13 @@ def _is_wildcard_matching(selector_value: str, package_value: str) -> bool: def _is_matching(selector: PackageSelector, query: SelectorQuery) -> bool: + manifest = query.package.manifest return ( # Package data. _is_wildcard_matching(selector.hash, query.package.hash) and _is_wildcard_matching(selector.namespace, query.package.manifest.namespace) and _is_wildcard_matching(selector.short_name, query.package.manifest.short_name) - and (selector.version == "*" or query.package.manifest.version.match(selector.version)) + and (selector.version == "*" or semver.Version.parse(manifest.version).match(selector.version)) # Package origin. and _is_wildcard_matching(selector.origin.repositories, "*") # TODO: handle repositories and (selector.origin.local is None or selector.origin.local == query.package.sources.is_local()) diff --git a/tests/conftest.py b/tests/conftest.py index 40e99eaa..ddfe1bc4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -16,6 +16,8 @@ from questionpy_common.constants import DIST_DIR, MANIFEST_FILENAME, MiB from questionpy_common.environment import PackagePermissions from questionpy_common.manifest import PackageFile +from questionpy_common.package_location import DirPackageLocation, ZipPackageLocation +from questionpy_server.dependencies._dynamic_resolver_abc import NoopDependencyResolver from questionpy_server.hash import calculate_hash from questionpy_server.settings import ( AuthSettings, @@ -29,12 +31,11 @@ WebserviceSettings, WorkerPoolSettings, ) -from questionpy_server.utils.manifest import ComparableManifest +from questionpy_server.utils.manifest import Manifest from questionpy_server.web.app import QPyServer from questionpy_server.worker.impl.subprocess import SubprocessWorker from questionpy_server.worker.impl.thread import ThreadWorker from questionpy_server.worker.pool import WorkerPool -from questionpy_server.worker.runtime.package_location import DirPackageLocation, ZipPackageLocation @dataclass(unsafe_hash=True) @@ -45,7 +46,7 @@ def __init__(self, path: Path): super().__init__(path, calculate_hash(path)) with ZipFile(self.path) as package: - self.manifest = ComparableManifest.model_validate_json(package.read(f"{DIST_DIR}/{MANIFEST_FILENAME}")) + self.manifest = Manifest.model_validate_json(package.read(f"{DIST_DIR}/{MANIFEST_FILENAME}")) @dataclass(unsafe_hash=True) @@ -55,7 +56,7 @@ class TestDirPackage(DirPackageLocation): def __init__(self, path: Path) -> None: super().__init__(path) - self.manifest = ComparableManifest.model_validate_json((path / MANIFEST_FILENAME).read_text()) + self.manifest = Manifest.model_validate_json((path / MANIFEST_FILENAME).read_text()) def inject_static_file_into_dist(self, name: str, content: str | bytes) -> int: """Inserts a static file only into dist. Can be used to produce invalid static file configurations.""" @@ -152,5 +153,6 @@ def package_factory(tmp_path_factory: pytest.TempPathFactory) -> TestPackageFact @pytest.fixture(params=(SubprocessWorker, ThreadWorker)) async def worker_pool(request: pytest.FixtureRequest) -> AsyncGenerator[WorkerPool]: - async with WorkerPool(1, 512 * MiB, worker_type=request.param) as pool: + mock_resolver = NoopDependencyResolver() + async with WorkerPool(1, 512 * MiB, worker_type=request.param, dependency_resolver=mock_resolver) as pool: yield pool diff --git a/tests/questionpy_server/collector/test_indexer.py b/tests/questionpy_server/collector/test_indexer.py index 202913da..6f38a02f 100644 --- a/tests/questionpy_server/collector/test_indexer.py +++ b/tests/questionpy_server/collector/test_indexer.py @@ -7,6 +7,7 @@ from unittest.mock import patch import pytest +from semver import Version from questionpy_server import WorkerPool from questionpy_server.collector.abc import BaseCollector @@ -15,7 +16,7 @@ from questionpy_server.collector.local_collector import LocalCollector from questionpy_server.collector.repo_collector import RepoCollector from questionpy_server.package import PackageSources -from questionpy_server.utils.manifest import ComparableManifest +from questionpy_server.utils.manifest import Manifest from tests.conftest import PACKAGE @@ -23,10 +24,10 @@ @patch("questionpy_server.collector.lms_collector.LMSCollector", spec=LMSCollector) async def test_register_package_with_path_and_manifest( collector: LMSCollector, - kind: Path | ComparableManifest, + kind: Path | Manifest, worker_pool: WorkerPool, ) -> None: - indexer = Indexer(worker_pool) + indexer = Indexer() await indexer.register_package(PACKAGE.hash, kind, collector) # Package is accessible by hash. @@ -38,11 +39,13 @@ async def test_register_package_with_path_and_manifest( @patch("questionpy_server.collector.lms_collector.LMSCollector", spec=LMSCollector) async def test_register_package_from_lms(collector: LMSCollector, worker_pool: WorkerPool) -> None: - indexer = Indexer(worker_pool) + indexer = Indexer() await indexer.register_package(PACKAGE.hash, PACKAGE.manifest, collector) # Package is not accessible by identifier and version. - package = indexer.get_by_identifier_and_version(PACKAGE.manifest.identifier, PACKAGE.manifest.version) + package = indexer.get_by_identifier_and_version( + PACKAGE.manifest.identifier, Version.parse(PACKAGE.manifest.version) + ) assert package is None # Package is not accessible by identifier. @@ -61,7 +64,7 @@ async def test_register_package_from_local_and_repo_collector( # Create mock. collector = patch(collector.__module__, spec=collector).start() - indexer = Indexer(worker_pool) + indexer = Indexer() await indexer.register_package(PACKAGE.hash, PACKAGE.manifest, collector) # Package is accessible by hash. @@ -71,14 +74,16 @@ async def test_register_package_from_local_and_repo_collector( assert package.manifest == PACKAGE.manifest # Package is accessible by identifier and version. - new_package = indexer.get_by_identifier_and_version(PACKAGE.manifest.identifier, PACKAGE.manifest.version) + new_package = indexer.get_by_identifier_and_version( + PACKAGE.manifest.identifier, Version.parse(PACKAGE.manifest.version) + ) assert new_package is not None assert new_package is package # Package is accessible by identifier. packages_by_identifier = indexer.get_by_identifier(PACKAGE.manifest.identifier) assert len(packages_by_identifier) == 1 - assert packages_by_identifier[package.manifest.version] is package + assert packages_by_identifier[Version.parse(package.manifest.version)] is package # Package is accessible by retrieving all packages. packages = indexer.get_package_versions_infos() @@ -87,7 +92,7 @@ async def test_register_package_from_local_and_repo_collector( async def test_register_package_with_same_hash_as_existing_package(worker_pool: WorkerPool) -> None: - indexer = Indexer(worker_pool) + indexer = Indexer() # Register package from local collector. local_collector = patch(LocalCollector.__module__, spec=LocalCollector).start() @@ -110,7 +115,7 @@ async def test_register_package_with_same_hash_as_existing_package(worker_pool: # Package will only be listed once. packages_by_identifier = indexer.get_by_identifier(PACKAGE.manifest.identifier) assert len(packages_by_identifier) == 1 - assert packages_by_identifier[package.manifest.version] is package + assert packages_by_identifier[Version.parse(package.manifest.version)] is package packages = indexer.get_package_versions_infos() assert len(packages) == 1 @@ -124,7 +129,7 @@ async def test_register_two_packages_with_same_manifest_but_different_hashes( collector = patch(LocalCollector.__module__, spec=LocalCollector).start() # Register a package. - indexer = Indexer(worker_pool) + indexer = Indexer() await indexer.register_package(PACKAGE.hash, PACKAGE.manifest, collector) with caplog.at_level(logging.WARNING): @@ -139,7 +144,7 @@ async def test_register_two_packages_with_same_manifest_but_different_hashes( async def test_unregister_package_with_lms_source(worker_pool: WorkerPool) -> None: - indexer = Indexer(worker_pool) + indexer = Indexer() collector = patch(LMSCollector.__module__, spec=LMSCollector).start() await indexer.register_package(PACKAGE.hash, PACKAGE.manifest, collector) @@ -152,7 +157,7 @@ async def test_unregister_package_with_lms_source(worker_pool: WorkerPool) -> No @pytest.mark.parametrize("collector", [LocalCollector, RepoCollector]) async def test_unregister_package_with_local_and_repo_source(collector: BaseCollector, worker_pool: WorkerPool) -> None: - indexer = Indexer(worker_pool) + indexer = Indexer() collector = patch(collector.__module__, spec=collector).start() await indexer.register_package(PACKAGE.hash, PACKAGE.manifest, collector) @@ -163,7 +168,9 @@ async def test_unregister_package_with_local_and_repo_source(collector: BaseColl assert package is None # Package is not accessible by identifier and version. - package = indexer.get_by_identifier_and_version(PACKAGE.manifest.identifier, PACKAGE.manifest.version) + package = indexer.get_by_identifier_and_version( + PACKAGE.manifest.identifier, Version.parse(PACKAGE.manifest.version) + ) assert package is None # Package is not accessible by identifier. @@ -172,7 +179,7 @@ async def test_unregister_package_with_local_and_repo_source(collector: BaseColl async def test_unregister_package_with_multiple_sources(worker_pool: WorkerPool) -> None: - indexer = Indexer(worker_pool) + indexer = Indexer() # Register package from local, repo, and LMS collector. lms_collector = patch(LMSCollector.__module__, spec=LMSCollector).start() @@ -192,7 +199,9 @@ async def test_unregister_package_with_multiple_sources(worker_pool: WorkerPool) assert package is not None # Package is still accessible by identifier and version. - package = indexer.get_by_identifier_and_version(PACKAGE.manifest.identifier, PACKAGE.manifest.version) + package = indexer.get_by_identifier_and_version( + PACKAGE.manifest.identifier, Version.parse(PACKAGE.manifest.version) + ) assert package is not None # Package is still accessible by identifier. @@ -207,7 +216,9 @@ async def test_unregister_package_with_multiple_sources(worker_pool: WorkerPool) assert package is not None # Package is not accessible by identifier and version. - package = indexer.get_by_identifier_and_version(PACKAGE.manifest.identifier, PACKAGE.manifest.version) + package = indexer.get_by_identifier_and_version( + PACKAGE.manifest.identifier, Version.parse(PACKAGE.manifest.version) + ) assert package is None # Package is not accessible by identifier. diff --git a/tests/questionpy_server/collector/test_lms_collector.py b/tests/questionpy_server/collector/test_lms_collector.py index eeef03ec..acdf7f51 100644 --- a/tests/questionpy_server/collector/test_lms_collector.py +++ b/tests/questionpy_server/collector/test_lms_collector.py @@ -26,7 +26,7 @@ def create_lms_collector(tmp_path_factory: TempPathFactory, worker_pool: WorkerP cache_path.mkdir() cache = LRUCache(supervisor_cache, cache_path, extension=".qpy") - indexer = Indexer(worker_pool) + indexer = Indexer() return LMSCollector(cache, indexer), cache diff --git a/tests/questionpy_server/collector/test_local_collector.py b/tests/questionpy_server/collector/test_local_collector.py index 8b3bb4e3..454af7cd 100644 --- a/tests/questionpy_server/collector/test_local_collector.py +++ b/tests/questionpy_server/collector/test_local_collector.py @@ -25,7 +25,7 @@ def create_local_collector(tmp_path_factory: TempPathFactory, worker_pool: WorkerPool) -> tuple[LocalCollector, Path]: """Create and return a local collector along with the directory it is using.""" path = tmp_path_factory.mktemp("qpy") - indexer = Indexer(worker_pool) + indexer = Indexer() return LocalCollector(path, indexer), path @@ -81,7 +81,7 @@ async def test_ignore_files_with_wrong_extension(tmp_path_factory: TempPathFacto ignore_file = directory / "wrong.extension" ignore_file.touch() - indexer = Indexer(worker_pool) + indexer = Indexer() local_collector = LocalCollector(directory, indexer) async with local_collector: @@ -97,7 +97,7 @@ async def test_ignore_files_with_wrong_extension(tmp_path_factory: TempPathFacto async def test_package_exists_before_init(tmp_path_factory: TempPathFactory, worker_pool: WorkerPool) -> None: path = tmp_path_factory.mktemp("qpy") - indexer = Indexer(worker_pool) + indexer = Indexer() local_collector = LocalCollector(path, indexer) package_path = copy(PACKAGE.path, path) @@ -254,7 +254,7 @@ async def test_package_gets_moved_to_different_folder( # Use new_directory as the directory to be watched and directory to be the new directory of the package. directory, new_directory = new_directory, directory - indexer = Indexer(worker_pool) + indexer = Indexer() local_collector = LocalCollector(directory, indexer) # Create a package in the directory. diff --git a/tests/questionpy_server/collector/test_package_collection.py b/tests/questionpy_server/collector/test_package_collection.py index e1257999..a9a00ffe 100644 --- a/tests/questionpy_server/collector/test_package_collection.py +++ b/tests/questionpy_server/collector/test_package_collection.py @@ -17,7 +17,7 @@ async def test_start() -> None: - package_collection = PackageCollection(Path("test_dir/"), {}, Mock(), Mock(), Mock()) + package_collection = PackageCollection(Path("test_dir/"), {}, Mock(), Mock()) with patch.object(LMSCollector, "start") as lms_start, patch.object(LocalCollector, "start") as local_start: await package_collection.start() @@ -26,7 +26,7 @@ async def test_start() -> None: async def test_stop() -> None: - package_collection = PackageCollection(Path("test_dir/"), {}, Mock(), Mock(), Mock()) + package_collection = PackageCollection(Path("test_dir/"), {}, Mock(), Mock()) with patch.object(LMSCollector, "stop") as lms_stop, patch.object(LocalCollector, "stop") as local_stop: await package_collection.stop() @@ -35,7 +35,7 @@ async def test_stop() -> None: async def test_put_package() -> None: - package_collection = PackageCollection(None, {}, Mock(), Mock(), Mock()) + package_collection = PackageCollection(None, {}, Mock(), Mock()) with patch.object(LMSCollector, "put") as put: await package_collection.put(HashContainer(b"", "hash")) @@ -43,7 +43,7 @@ async def test_put_package() -> None: def test_get_package() -> None: - package_collection = PackageCollection(None, {}, Mock(), Mock(), Mock()) + package_collection = PackageCollection(None, {}, Mock(), Mock()) # Package does exist. with patch.object(Indexer, "get_by_hash") as get_by_hash: @@ -57,7 +57,7 @@ def test_get_package() -> None: def test_get_package_by_identifier() -> None: - package_collection = PackageCollection(None, {}, Mock(), Mock(), Mock()) + package_collection = PackageCollection(None, {}, Mock(), Mock()) with patch.object(Indexer, "get_by_identifier") as get_by_identifier: package_collection.get_by_identifier("@default/name") @@ -65,7 +65,7 @@ def test_get_package_by_identifier() -> None: def test_get_package_by_identifier_and_version() -> None: - package_collection = PackageCollection(None, {}, Mock(), Mock(), Mock()) + package_collection = PackageCollection(None, {}, Mock(), Mock()) # Package does exist. with patch.object(Indexer, "get_by_identifier_and_version") as get_by_identifier_and_version: @@ -81,7 +81,7 @@ def test_get_package_by_identifier_and_version() -> None: def test_get_packages() -> None: - package_collection = PackageCollection(None, {}, Mock(), Mock(), Mock()) + package_collection = PackageCollection(None, {}, Mock(), Mock()) # Package does exist. with patch.object(Indexer, "get_package_versions_infos") as get_package_versions_infos: @@ -96,7 +96,7 @@ async def test_notify_indexer_on_cache_deletion(tmp_path_factory: TempPathFactor cache = LRUCache(supervisor, cache_path, extension=".qpy") await cache.put("hash", b"") - PackageCollection(None, {}, Mock(), cache, Mock()) + PackageCollection(None, {}, Mock(), cache) # The callback should unregister the package from the indexer. with patch.object(Indexer, "unregister_package") as unregister_package: diff --git a/tests/questionpy_server/repository/test_repository.py b/tests/questionpy_server/repository/test_repository.py index de29f7f9..f9aeaf9c 100644 --- a/tests/questionpy_server/repository/test_repository.py +++ b/tests/questionpy_server/repository/test_repository.py @@ -14,7 +14,7 @@ from questionpy_common.constants import KiB from questionpy_server.cache import CacheItemTooLargeError, LRUCache, LRUCacheSupervisor from questionpy_server.repository import RepoMeta, RepoPackage, RepoPackageIndex, Repository -from questionpy_server.utils.manifest import ComparableManifest +from questionpy_server.utils.manifest import Manifest from tests.test_data.factories import ManifestFactory, RepoMetaFactory, RepoPackageVersionsFactory REPO_URL = "https://example.com/repo/" @@ -88,7 +88,7 @@ async def test_get_packages(tmp_path_factory: TempPathFactory) -> None: expected_manifest["api_version"] = versions.api_version # Check if the combined manifest is correct. - assert package.manifest == ComparableManifest(**expected_manifest) + assert package.manifest == Manifest(**expected_manifest) async def test_get_packages_cached(tmp_path_factory: TempPathFactory) -> None: diff --git a/tests/questionpy_server/web/routes/test_packages.py b/tests/questionpy_server/web/routes/test_packages.py index d1cdc062..b0565ddc 100644 --- a/tests/questionpy_server/web/routes/test_packages.py +++ b/tests/questionpy_server/web/routes/test_packages.py @@ -13,7 +13,7 @@ from questionpy_server.collector.local_collector import LocalCollector from questionpy_server.models import PackageVersionInfo, PackageVersionsInfo, RequestErrorCode -from questionpy_server.utils.manifest import ComparableManifest +from questionpy_server.utils.manifest import Manifest from questionpy_server.web.app import QPyServer from tests.conftest import PACKAGE from tests.test_data.factories import ManifestFactory @@ -35,11 +35,11 @@ ], ) async def test_packages(qpy_server: QPyServer, aiohttp_client: AiohttpClient, packages: dict[str, set[str]]) -> None: - async def add_package_version(server: QPyServer, manifest: ComparableManifest) -> None: - package_hash = sha256((manifest.short_name + manifest.namespace + str(manifest.version)).encode()).hexdigest() + async def add_package_version(server: QPyServer, manifest: Manifest) -> None: + package_hash = sha256((manifest.short_name + manifest.namespace + manifest.version).encode()).hexdigest() await server.package_collection._indexer.register_package(package_hash, manifest, Mock(spec=LocalCollector)) - manifests: dict[str, dict[str, ComparableManifest]] = {} + manifests: dict[str, dict[str, Manifest]] = {} for namespace, versions in packages.items(): for version in versions: expected_manifest = ManifestFactory.build(namespace=namespace, short_name=namespace, version=version) diff --git a/tests/questionpy_server/worker/impl/test_base.py b/tests/questionpy_server/worker/impl/test_base.py index 529a489f..b6ca543d 100644 --- a/tests/questionpy_server/worker/impl/test_base.py +++ b/tests/questionpy_server/worker/impl/test_base.py @@ -19,7 +19,7 @@ from tests.questionpy_server.worker.impl.conftest import patch_worker_pool if TYPE_CHECKING: - from questionpy_server.worker.runtime.package_location import PackageLocation + from questionpy_common.package_location import PackageLocation async def test_should_get_manifest(worker_pool: WorkerPool) -> None: diff --git a/tests/test_data/factories.py b/tests/test_data/factories.py index 21ba16c1..cda2cf32 100644 --- a/tests/test_data/factories.py +++ b/tests/test_data/factories.py @@ -9,9 +9,9 @@ from polyfactory.factories.pydantic_factory import ModelFactory from semver import Version -from questionpy_common.manifest import Bcp47LanguageTag, PartialPackagePermissions +from questionpy_common.manifest import Bcp47LanguageTag, DistDependencies, PartialPackagePermissions from questionpy_server.repository.models import RepoMeta, RepoPackageVersions -from questionpy_server.utils.manifest import ComparableManifest +from questionpy_server.utils.manifest import Manifest class CustomFactory(ModelFactory[Any]): @@ -31,9 +31,11 @@ class RepoMetaFactory(ModelFactory): class RepoPackageVersionsFactory(CustomFactory): __model__ = RepoPackageVersions + manifest = Use(lambda: ManifestFactory.build()) # noqa: PLW0108 + class ManifestFactory(CustomFactory): - __model__ = ComparableManifest + __model__ = Manifest short_name = Use(lambda: ModelFactory.__faker__.word().lower() + "_sn") namespace = Use(lambda: ModelFactory.__faker__.word().lower() + "_ns") @@ -42,3 +44,4 @@ class ManifestFactory(CustomFactory): url = Use(ModelFactory.__faker__.url) icon = None permissions = PartialPackagePermissions() + dependencies = DistDependencies(qpy=[])