diff --git a/flake.lock b/flake.lock index 394d1d3a..d7bbdc91 100644 --- a/flake.lock +++ b/flake.lock @@ -36,11 +36,11 @@ }, "nixpkgs-unstable": { "locked": { - "lastModified": 1772624091, - "narHash": "sha256-QKyJ0QGWBn6r0invrMAK8dmJoBYWoOWy7lN+UHzW1jc=", + "lastModified": 1774709303, + "narHash": "sha256-D3Q07BbIA2KnTcSXIqqu9P586uWxN74zNoCH3h2ESHg=", "owner": "nixos", "repo": "nixpkgs", - "rev": "80bdc1e5ce51f56b19791b52b2901187931f5353", + "rev": "8110df5ad7abf5d4c0f6fb0f8f978390e77f9685", "type": "github" }, "original": { @@ -63,11 +63,11 @@ ] }, "locked": { - "lastModified": 1772555609, - "narHash": "sha256-3BA3HnUvJSbHJAlJj6XSy0Jmu7RyP2gyB/0fL7XuEDo=", + "lastModified": 1773870109, + "narHash": "sha256-ZoTdqZP03DcdoyxvpFHCAek4bkPUTUPUF3oCCgc3dP4=", "owner": "pyproject-nix", "repo": "build-system-pkgs", - "rev": "c37f66a953535c394244888598947679af231863", + "rev": "b6e74f433b02fa4b8a7965ee24680f4867e2926f", "type": "github" }, "original": { @@ -83,11 +83,11 @@ ] }, "locked": { - "lastModified": 1771518446, - "narHash": "sha256-nFJSfD89vWTu92KyuJWDoTQJuoDuddkJV3TlOl1cOic=", + "lastModified": 1774498001, + "narHash": "sha256-wTfdyzzrmpuqt4TQQNqilF91v0m5Mh1stNy9h7a/WK4=", "owner": "nix-community", "repo": "pyproject.nix", - "rev": "eb204c6b3335698dec6c7fc1da0ebc3c6df05937", + "rev": "794afa6eb588b498344f2eaa36ab1ceb7e6b0b09", "type": "github" }, "original": { @@ -131,11 +131,11 @@ ] }, "locked": { - "lastModified": 1772545244, - "narHash": "sha256-Ys+5UMOqp2kRvnSjyBcvGnjOhkIXB88On1ZcAstz1vY=", + "lastModified": 1774929536, + "narHash": "sha256-dMTjy8hu4XFAdNHdcLtCryN3SHqSUFHHqDLep+3b2v4=", "owner": "pyproject-nix", "repo": "uv2nix", - "rev": "482aba340ded40ef557d331315f227d5eba84ced", + "rev": "5d0e883867b1cf53263fcf1bfd34542d40abf5a9", "type": "github" }, "original": { diff --git a/pyproject.toml b/pyproject.toml index a41595e1..edf2d995 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,7 @@ debug = [ test = [ "pyhamcrest>=2.1.0", "pytest>=8.3.3", + 
"pytest-asyncio>=1.3.0", "pytest-cov>=5.0.0", "pytest-timeout>=2.4.0", "pytest-repeat>=0.9.4", diff --git a/src/dvsim/instrumentation/metadata.py b/src/dvsim/instrumentation/metadata.py index 7871db1a..f674969f 100644 --- a/src/dvsim/instrumentation/metadata.py +++ b/src/dvsim/instrumentation/metadata.py @@ -29,6 +29,7 @@ class MetadataJobFragment(JobFragment): job_type: str target: str tool: str + backend: str | None dependencies: list[str] status: str @@ -61,6 +62,7 @@ def build_report_fragments(self) -> InstrumentationFragments | None: spec.job_type, spec.target, spec.tool.name, + spec.backend, spec.dependencies, status_str, ) diff --git a/src/dvsim/instrumentation/resources.py b/src/dvsim/instrumentation/resources.py index d02ff82c..8b9505cb 100644 --- a/src/dvsim/instrumentation/resources.py +++ b/src/dvsim/instrumentation/resources.py @@ -227,7 +227,7 @@ def on_job_status_change(self, job: JobSpec, status: JobStatus) -> None: with self._lock: running = job_id in self._running_jobs started = running or job_id in self._finished_jobs - if not started and status != JobStatus.QUEUED: + if not started and status not in (JobStatus.SCHEDULED, JobStatus.QUEUED): self._running_jobs[job_id] = JobResourceAggregate(job) running = True if running and status.is_terminal: diff --git a/src/dvsim/instrumentation/timing.py b/src/dvsim/instrumentation/timing.py index d0c1192b..9766c34a 100644 --- a/src/dvsim/instrumentation/timing.py +++ b/src/dvsim/instrumentation/timing.py @@ -99,7 +99,7 @@ def on_job_status_change(self, job: JobSpec, status: JobStatus) -> None: job_info = TimingJobFragment(job) self._jobs[job_id] = job_info - if job_info.start_time is None and status != JobStatus.QUEUED: + if job_info.start_time is None and status not in (JobStatus.SCHEDULED, JobStatus.QUEUED): job_info.start_time = time.perf_counter() if status.is_terminal: job_info.end_time = time.perf_counter() diff --git a/src/dvsim/job/data.py b/src/dvsim/job/data.py index d820738c..39d9a936 100644 --- 
a/src/dvsim/job/data.py +++ b/src/dvsim/job/data.py @@ -54,6 +54,11 @@ class JobSpec(BaseModel): target: str """run phase [build, run, ...]""" + backend: str | None + """The runtime backend to execute this job with. If not provided (None), this + indicates that whatever is configured as the 'default' backend should be used. + """ + seed: int | None """Seed if there is one.""" diff --git a/src/dvsim/job/deploy.py b/src/dvsim/job/deploy.py index 71247b85..9bfd409d 100644 --- a/src/dvsim/job/deploy.py +++ b/src/dvsim/job/deploy.py @@ -110,6 +110,9 @@ def get_job_spec(self) -> "JobSpec": name=self.name, job_type=self.__class__.__name__, target=self.target, + # TODO: for now we always use the default configured backend, but it might be good + # to allow different jobs to run on different backends in the future? + backend=None, seed=getattr(self, "seed", None), full_name=self.full_name, qual_name=self.qual_name, diff --git a/src/dvsim/job/status.py b/src/dvsim/job/status.py index e409a155..457f3bc2 100644 --- a/src/dvsim/job/status.py +++ b/src/dvsim/job/status.py @@ -12,11 +12,14 @@ class JobStatus(Enum): """Status of a Job.""" - QUEUED = auto() - RUNNING = auto() - PASSED = auto() - FAILED = auto() - KILLED = auto() + # SCHEDULED is currently unused in the old sync scheduler, where `SCHEDULED` and `QUEUED` + # are combined under `QUEUED`. It is used only in the new async scheduler. 
+ SCHEDULED = auto() # Waiting for dependencies + QUEUED = auto() # Dependencies satisfied, waiting to be dispatched + RUNNING = auto() # Dispatched to a backend and actively executing + PASSED = auto() # Completed successfully + FAILED = auto() # Completed with failure + KILLED = auto() # Forcibly terminated or never executed @property def shorthand(self) -> str: diff --git a/src/dvsim/scheduler/async_core.py b/src/dvsim/scheduler/async_core.py new file mode 100644 index 00000000..651ad983 --- /dev/null +++ b/src/dvsim/scheduler/async_core.py @@ -0,0 +1,582 @@ +# Copyright lowRISC contributors (OpenTitan project). +# Licensed under the Apache License, Version 2.0, see LICENSE for details. +# SPDX-License-Identifier: Apache-2.0 + +"""Job scheduler.""" + +import asyncio +import heapq +from collections import defaultdict +from collections.abc import Callable, Iterable, Mapping, Sequence +from dataclasses import dataclass, field +from signal import SIGINT, SIGTERM, getsignal, signal +from types import FrameType +from typing import Any, TypeAlias + +from dvsim.job.data import CompletedJobStatus, JobSpec, JobStatusInfo +from dvsim.job.status import JobStatus +from dvsim.logging import log +from dvsim.runtime.backend import RuntimeBackend +from dvsim.runtime.data import JobCompletionEvent, JobHandle + +__all__ = ( + "JobPriorityFn", + "JobRecord", + "OnJobStatusChangeCb", + "OnRunEndCb", + "OnRunStartCb", + "OnSchedulerKillCb", + "Priority", + "Scheduler", +) + + +@dataclass +class JobRecord: + """Mutable runtime representation of a scheduled job, used in the scheduler.""" + + spec: JobSpec + backend_key: str # either spec.backend, or the default backend if not given + + status: JobStatus = JobStatus.SCHEDULED + status_info: JobStatusInfo | None = None + + remaining_deps: int = 0 + passing_deps: int = 0 + dependents: list[str] = field(default_factory=list) + kill_requested: bool = False + + handle: JobHandle | None = None + + +# Function to assign a priority to a given 
job specification. The returned priority should be +# some lexicographically orderable type. Jobs with higher priority are scheduled first. +Priority: TypeAlias = int | float | Sequence[int | float] +JobPriorityFn: TypeAlias = Callable[[JobRecord], Priority] + +# Callbacks for observers, for when the scheduler run starts and stops +OnRunStartCb: TypeAlias = Callable[[], None] +OnRunEndCb: TypeAlias = Callable[[], None] + +# Callbacks for observers, for when a job status changes in the scheduler +# The arguments are: (job spec, old status, new status). +OnJobStatusChangeCb: TypeAlias = Callable[[JobSpec, JobStatus, JobStatus], None] + +# Callbacks for observers, for when the scheduler receives a kill signal (termination). +OnSchedulerKillCb: TypeAlias = Callable[[], None] + + +# Standard context messages used for killed/failed jobs in the scheduler. +FAILED_DEP = JobStatusInfo( + message="Job cancelled because one of its dependencies failed or was killed." +) +ALL_FAILED_DEP = JobStatusInfo( + message="Job cancelled because all of its dependencies failed or were killed." +) +KILLED_SCHEDULED = JobStatusInfo( + message="Job cancelled because one of its dependencies was killed." +) +KILLED_QUEUED = JobStatusInfo(message="Job killed whilst waiting to begin execution.") +KILLED_RUNNING_SIGINT = JobStatusInfo( + message="Job killed by a SIGINT signal to the scheduler whilst executing." +) +KILLED_RUNNING_SIGTERM = JobStatusInfo( + message="Job killed by a SIGTERM signal to the scheduler whilst executing." +) + + +class Scheduler: + """Event-driven job scheduler that schedules and runs a DAG of job specifications.""" + + def __init__( # noqa: PLR0913 + self, + jobs: Iterable[JobSpec], + backends: Mapping[str, RuntimeBackend], + default_backend: str, + *, + max_parallelism: int = 0, + priority_fn: JobPriorityFn | None = None, + coalesce_window: float | None = 0.001, + ) -> None: + """Construct a new scheduler to run a DAG of jobs. + + Args: + jobs: The DAG of jobs to run. 
A sequence of job specifications, where the DAG is + defined by the job IDs and job dependency lists. + backends: The mapping (name -> backend) of backends available to the scheduler. + default_backend: The name of the default backend to use if not specified by a job. + max_parallelism: The maximum number of jobs that the scheduler is allowed to dispatch + at once, across all backends. The default value of `0` indicates no upper limit. + priority_fn: A function to calculate the priority of a given job. If no function is + given, this defaults to using the job's weight. + coalesce_window: If specified, the time in seconds to wait on receiving a job + completion, to give a short amount of time to allow other batched completion events + to arrive in the queue. This lets us batch scheduling more frequently for a little + extra cost. Defaults to 1 millisecond, and can be disabled by giving `None`. + + """ + if max_parallelism < 0: + err = f"max_parallelism must be some non-negative integer, not {max_parallelism}" + raise ValueError(err) + if default_backend not in backends: + err = f"Default backend '{default_backend}' is not in the mapping of given backends" + raise ValueError(err) + if coalesce_window is not None and coalesce_window < 0.0: + raise ValueError("coalesce_window must be None or some non-negative number") + + # Configuration of the scheduler's behaviour + self.backends = dict(backends) + self.default_backend = default_backend + self.max_parallelism = max_parallelism + self.priority_fn = priority_fn or self._default_priority + self.coalesce_window = coalesce_window + + # Internal data structures and indexes to track running jobs. 
+ self._jobs: dict[str, JobRecord] = {} + self._ready_heap: list[tuple[Priority, str]] = [] + self._running: set[str] = set() + self._running_per_backend: dict[str, int] = dict.fromkeys(backends, 0) + self._event_queue: asyncio.Queue[Iterable[JobCompletionEvent]] = asyncio.Queue() + + # Internal flags and signal handling + self._shutdown_signal: int | None = None + self._shutdown_event: asyncio.Event | None = None + self._original_sigint_handler: Any = None + self._shutdown_started = False + + # Registered callbacks from observers + self._on_run_start: list[OnRunStartCb] = [] + self._on_run_end: list[OnRunEndCb] = [] + self._on_job_status_change: list[OnJobStatusChangeCb] = [] + self._on_kill_signal: list[OnSchedulerKillCb] = [] + + self._build_graph(jobs) + + def add_run_start_callback(self, cb: OnRunStartCb) -> None: + """Register an observer to notify when the scheduler run is started.""" + self._on_run_start.append(cb) + + def add_run_end_callback(self, cb: OnRunEndCb) -> None: + """Register an observer to notify when the scheduler run ends.""" + self._on_run_end.append(cb) + + def add_job_status_change_callback(self, cb: OnJobStatusChangeCb) -> None: + """Register an observer to notify when the status of a job in the scheduler changes.""" + self._on_job_status_change.append(cb) + + def add_kill_signal_callback(self, cb: OnSchedulerKillCb) -> None: + """Register an observer to notify when the scheduler is killed by some signal.""" + self._on_kill_signal.append(cb) + + def _default_priority(self, job: JobRecord) -> Priority: + """Prioritizes jobs according to their weight. 
The default prioritization method.""" + return job.spec.weight + + def _build_graph(self, specs: Iterable[JobSpec]) -> None: + """Build the job dependency graph and validate the DAG structure.""" + # Build an index of runtime job records, and check for duplicates + for spec in specs: + if spec.id in self._jobs: + log.warning("Duplicate job ID '%s'", spec.id) + # TODO: when we're sure it's ok, change the behaviour to error on duplicate jobs + # : err = f"Duplicate job ID '{spec.id}'" + # : raise ValueError(err) + # Instead, silently ignore it for now to match the original scheduler behaviour + continue + if spec.backend is not None and spec.backend not in self.backends: + err = f"Unknown job backend '{spec.backend}'" + raise ValueError(err) + backend_name = self.default_backend if spec.backend is None else spec.backend + self._jobs[spec.id] = JobRecord(spec=spec, backend_key=backend_name) + + # Build a graph from the adjacency list formed by the spec dependencies + for job in self._jobs.values(): + job.remaining_deps = len(job.spec.dependencies) + for dep in job.spec.dependencies: + if dep not in self._jobs: + err = f"Unknown job dependency '{dep}' for job {job.spec.id}" + raise ValueError(err) + self._jobs[dep].dependents.append(job.spec.id) + + # Validate that there are no cycles in the given graph. 
+ self._validate_acyclic() + + def _validate_acyclic(self) -> None: + """Validate that the given job digraph is acyclic via Kahn's Algorithm.""" + indegree = {job: record.remaining_deps for job, record in self._jobs.items()} + job_queue = [job for job, degree in indegree.items() if degree == 0] + num_visited = 0 + + while job_queue: + job = job_queue.pop() + num_visited += 1 + for dep in self._jobs[job].dependents: + indegree[dep] -= 1 + if indegree[dep] == 0: + job_queue.append(dep) + + if num_visited != len(self._jobs): + raise ValueError("The given JobSpec graph contains a dependency cycle.") + + def _notify_run_started(self) -> None: + """Notify any observers that the scheduler run has started.""" + for cb in self._on_run_start: + cb() + + def _notify_run_finished(self) -> None: + """Notify any observers that the scheduler run has finished.""" + for cb in self._on_run_end: + cb() + + def _notify_kill_signal(self) -> None: + """Notify any observers that the scheduler received a kill signal.""" + for cb in self._on_kill_signal: + cb() + + def _change_job_status( + self, job: JobRecord, new_status: JobStatus, info: JobStatusInfo | None = None + ) -> JobStatus: + """Change a job's runtime status, storing an optionally associated reason. + + Notifies any status change observers of the change, and returns the previous status. 
+ """ + old_status = job.status + if old_status == new_status: + return old_status + + job.status = new_status + job.status_info = info + + if new_status != JobStatus.RUNNING: + log.log( + log.ERROR if new_status in (JobStatus.FAILED, JobStatus.KILLED) else log.VERBOSE, + "Status change to [%s: %s] for %s", + new_status.shorthand, + new_status.name.capitalize(), + job.spec.full_name, + ) + + for cb in self._on_job_status_change: + cb(job.spec, old_status, new_status) + + return old_status + + def _mark_job_ready(self, job: JobRecord) -> None: + """Mark a given job in the scheduler as ready to execute (all dependencies completed).""" + if job.status != JobStatus.SCHEDULED: + msg = f"_mark_job_ready only applies to 'SCHEDULED' jobs (not '{job.status.name}')." + raise RuntimeError(msg) + + self._change_job_status(job, JobStatus.QUEUED) + # heapq is a min heap, so push (-priority) instead of (priority). + priority = self.priority_fn(job) + priority = priority if isinstance(priority, Sequence) else (priority,) + neg_priority: Priority = tuple(-x for x in priority) + heapq.heappush(self._ready_heap, (neg_priority, job.spec.id)) + + def _mark_job_running(self, job: JobRecord) -> None: + """Mark a given job in the scheduler as running. 
Assumes already removed from the heap.""" + if job.spec.id in self._running: + raise RuntimeError("_mark_job_running called on a job that was already running.") + + self._change_job_status(job, JobStatus.RUNNING) + self._running.add(job.spec.id) + self._running_per_backend[job.backend_key] += 1 + + def _mark_job_completed( + self, job: JobRecord, status: JobStatus, reason: JobStatusInfo | None + ) -> None: + """Mark a given job in the scheduler as completed, having reached some terminal state.""" + if not status.is_terminal: + err = f"_mark_job_completed called with non-terminal status '{status.name}'" + raise RuntimeError(err) + if job.status.is_terminal: + return + + # If the scheduler requested to kill the job, override the failure reason. + if job.kill_requested: + reason = ( + KILLED_RUNNING_SIGINT if self._shutdown_signal == SIGINT else KILLED_RUNNING_SIGTERM + ) + self._change_job_status(job, status, reason) + + # If the job was running, mark it as no longer running. + if job.spec.id in self._running: + self._running.remove(job.spec.id) + self._running_per_backend[job.backend_key] -= 1 + + # Update dependents (jobs that depend on this job), propagating failures if needed. 
+ self._update_completed_job_deps(job) + + def _update_completed_job_deps(self, job: JobRecord) -> None: + """Update the dependents of a completed job, scheduling/killing them where necessary.""" + for dep_id in job.dependents: + dep = self._jobs[dep_id] + + # Update dependency tracking counts in the dependency records + dep.remaining_deps -= 1 + if job.status == JobStatus.PASSED: + dep.passing_deps += 1 + + # Propagate kill signals on shutdown + if self._shutdown_signal is not None: + self._mark_job_completed(dep, JobStatus.KILLED, KILLED_SCHEDULED) + continue + + # Handle dependency management and marking dependents as ready + if dep.remaining_deps == 0 and dep.status == JobStatus.SCHEDULED: + if dep.spec.needs_all_dependencies_passing: + if dep.passing_deps == len(dep.spec.dependencies): + self._mark_job_ready(dep) + else: + self._mark_job_completed(dep, JobStatus.KILLED, FAILED_DEP) + elif dep.passing_deps > 0: + self._mark_job_ready(dep) + else: + self._mark_job_completed(dep, JobStatus.KILLED, ALL_FAILED_DEP) + + async def run(self) -> list[CompletedJobStatus]: + """Run all scheduled jobs to completion (unless terminated) and return the results.""" + self._install_signal_handlers() + + for backend in self.backends.values(): + backend.attach_completion_callback(self._submit_job_completion) + + self._notify_run_started() + + # Before entering the main loop, mark jobs with 0 remaining deps as ready to run. 
+ for job in self._jobs.values(): + if job.remaining_deps == 0: + self._mark_job_ready(job) + + try: + await self._main_loop() + finally: + self._notify_run_finished() + + return [ + CompletedJobStatus( + name=job.spec.name, + job_type=job.spec.job_type, + seed=job.spec.seed, + block=job.spec.block, + tool=job.spec.tool, + workspace_cfg=job.spec.workspace_cfg, + full_name=job.spec.full_name, + qual_name=job.spec.qual_name, + target=job.spec.target, + log_path=job.spec.log_path, + job_runtime=job.handle.job_runtime.with_unit("s").get()[0] + if job.handle is not None + else 0.0, + simulated_time=job.handle.simulated_time.with_unit("us").get()[0] + if job.handle is not None + else 0.0, + status=job.status, + fail_msg=job.status_info, + ) + for job in self._jobs.values() + ] + + def _install_signal_handlers(self) -> None: + """Install the SIGINT/SIGTERM signal handlers to trigger graceful shutdowns.""" + self._shutdown_signal = None + self._shutdown_event = asyncio.Event() + self._original_sigint_handler = getsignal(SIGINT) + self._shutdown_started = False + loop = asyncio.get_running_loop() + + def _handler(signum: int, _frame: FrameType | None) -> None: + if self._shutdown_signal is None and self._shutdown_event: + self._shutdown_signal = signum + loop.call_soon_threadsafe(self._shutdown_event.set) + + # Restore the original SIGINT handler so a second Ctrl-C terminates immediately + if signum == SIGINT: + signal(SIGINT, self._original_sigint_handler) + + loop.add_signal_handler(SIGINT, lambda: _handler(SIGINT, None)) + loop.add_signal_handler(SIGTERM, lambda: _handler(SIGTERM, None)) + + async def _submit_job_completion(self, events: Iterable[JobCompletionEvent]) -> None: + """Notify the scheduler that a batch of jobs have been completed.""" + try: + self._event_queue.put_nowait(events) + except asyncio.QueueShutDown as e: + msg = "Scheduler event queue shutdown earlier than expected?" 
+ raise RuntimeError(msg) from e + except asyncio.QueueFull: + log.critical("Scheduler event queue full despite being infinitely sized?") + + async def _main_loop(self) -> None: + """Run the main scheduler loop. + + Tries to schedule any ready jobs if there is available capacity, and then waits for any job + completions (or a shutdown signal). This continues in a loop until all jobs have been either + executed or killed (e.g. via a shutdown signal). + """ + if self._shutdown_event is None: + raise RuntimeError("Expected signal handlers to be installed before running main loop") + + job_completion_task = asyncio.create_task(self._event_queue.get()) + shutdown_task = asyncio.create_task(self._shutdown_event.wait()) + + try: + while True: + await self._schedule_ready_jobs() + + if not self._running: + if not self._ready_heap: + break + # This case (nothing running, but jobs still pending in the queue) can happen + # if backends fail to schedule any jobs (e.g. the backend is temporarily busy). + continue + + # Wait for any job to complete, or for a shutdown signal + try: + done, _ = await asyncio.wait( + (job_completion_task, shutdown_task), + return_when=asyncio.FIRST_COMPLETED, + ) + except asyncio.QueueShutDown as e: + msg = "Scheduler event queue shutdown earlier than expected?" 
+ raise RuntimeError(msg) from e + + if shutdown_task in done: + self._shutdown_event.clear() + shutdown_task = asyncio.create_task(self._shutdown_event.wait()) + await self._handle_exit_signal() + continue + + completions = await self._drain_completions(job_completion_task) + job_completion_task = asyncio.create_task(self._event_queue.get()) + + for event in completions: + job = self._jobs[event.spec.id] + self._mark_job_completed(job, event.status, event.reason) + finally: + job_completion_task.cancel() + shutdown_task.cancel() + + async def _drain_completions(self, completion_task: asyncio.Task) -> list[JobCompletionEvent]: + """Drain batched completions from the queue, optionally coalescing batched events.""" + events = list(completion_task.result()) + + # Coalesce nearby completions by waiting for a very short time + if self.coalesce_window is not None: + await asyncio.sleep(self.coalesce_window) + + # Drain any more completion events from the event queue + try: + while True: + events.extend(self._event_queue.get_nowait()) + except asyncio.QueueEmpty: + return events + except asyncio.QueueShutDown as e: + msg = "Scheduler event queue shutdown earlier than expected?" + raise RuntimeError(msg) from e + + async def _handle_exit_signal(self) -> None: + """Attempt to gracefully shutdown as a result of a triggered exit signal.""" + if self._shutdown_started: + return + self._shutdown_started = True + + signal_name = "SIGTERM" if self._shutdown_signal == SIGTERM else "SIGINT" + log.info("Received %s signal. Exiting gracefully", signal_name) + if self._shutdown_signal == SIGINT: + log.info( + "Send another to force immediate quit (but you may need to manually " + "kill some child processes)." + ) + + self._notify_kill_signal() + + # Mark any jobs that are currently running as jobs we should kill. + # Collect jobs to kill in a dict, grouped per backend, for batched killing. 
+ to_kill: dict[str, list[JobHandle]] = defaultdict(list) + + for job_id in self._running: + job = self._jobs[job_id] + if job.handle is None: + raise RuntimeError("Running job is missing an associated handle.") + job.kill_requested = True + to_kill[job.backend_key].append(job.handle) + + # Asynchronously dispatch backend kill tasks whilst we update scheduler internals. + # Jobs that depend on these jobs will then be transitively killed before they start. + kill_tasks: list[asyncio.Task] = [] + for backend_name, handles in to_kill.items(): + backend = self.backends[backend_name] + kill_tasks.append(asyncio.create_task(backend.kill_many(handles))) + + # Kill any ready (but not running jobs), so that they don't get scheduled. + while self._ready_heap: + _, job_id = heapq.heappop(self._ready_heap) + job = self._jobs[job_id] + self._mark_job_completed(job, JobStatus.KILLED, KILLED_QUEUED) + + if kill_tasks: + await asyncio.gather(*kill_tasks, return_exceptions=True) + + async def _schedule_ready_jobs(self) -> None: + """Attempt to schedule ready jobs whilst respecting scheduler & backend parallelism.""" + # Find out how many jobs we can dispatch according to the scheduler's parallelism limit + available_slots = ( + self.max_parallelism - len(self._running) + if self.max_parallelism + else len(self._ready_heap) + ) + if available_slots <= 0: + return + + # Collect jobs to launch in a dict, grouped per backend, for batched launching. 
+ to_launch: dict[str, list[tuple[Priority, JobRecord]]] = defaultdict(list) + blocked: list[tuple[Priority, str]] = [] + slots_used = 0 + + while self._ready_heap and slots_used < available_slots: + neg_priority, job_id = heapq.heappop(self._ready_heap) + job = self._jobs[job_id] + backend = self.backends[job.backend_key] + running_on_backend = self._running_per_backend[job.backend_key] + len( + to_launch[job.backend_key] + ) + + # Check that we can launch the job whilst respecting backend parallelism limits + if backend.max_parallelism and running_on_backend >= backend.max_parallelism: + blocked.append((neg_priority, job_id)) + continue + + to_launch[job.backend_key].append((neg_priority, job)) + slots_used += 1 + + # Requeue any blocked jobs. + for entry in blocked: + heapq.heappush(self._ready_heap, entry) + + # Launch the selected jobs in batches per backend + launch_tasks = [] + for backend_name, jobs in to_launch.items(): + backend = self.backends[backend_name] + job_specs = [job.spec for _, job in jobs] + log.verbose( + "[%s]: Dispatching jobs: %s", + backend_name, + ", ".join(job.full_name for job in job_specs), + ) + launch_tasks.append(backend.submit_many(job_specs)) + + results = await asyncio.gather(*launch_tasks) + + # Mark jobs running, and requeue any jobs that failed to launch + for jobs, handles in zip(to_launch.values(), results, strict=True): + for neg_priority, job in jobs: + handle = handles.get(job.spec.id) + if handle is None: + log.verbose("[%s]: Requeuing job '%s'", job.spec.target, job.spec.full_name) + heapq.heappush(self._ready_heap, (neg_priority, job.spec.id)) + continue + + job.handle = handle + self._mark_job_running(job) diff --git a/tests/job/test_status.py b/tests/job/test_status.py new file mode 100644 index 00000000..16ff28b2 --- /dev/null +++ b/tests/job/test_status.py @@ -0,0 +1,19 @@ +# Copyright lowRISC contributors (OpenTitan project). +# Licensed under the Apache License, Version 2.0, see LICENSE for details. 
+# SPDX-License-Identifier: Apache-2.0 + +"""Test Job (scheduler) status modelling.""" + +from hamcrest import assert_that, equal_to + +from dvsim.job.status import JobStatus + + +class TestJobStatus: + """Test scheduler JobStatus models.""" + + @staticmethod + def test_unique_shorthands() -> None: + """Test that all scheduler job statuses have unique shorthand representations.""" + shorthands = [status.shorthand for status in JobStatus] + assert_that(len(set(shorthands)), equal_to(len(shorthands))) diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 042357ed..a443f924 100644 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -23,38 +23,11 @@ from dvsim.job.status import JobStatus from dvsim.launcher.base import ErrorMessage, Launcher, LauncherBusyError, LauncherError from dvsim.report.data import IPMeta, ToolMeta -from dvsim.scheduler.core import Scheduler +from dvsim.runtime.legacy import LegacyLauncherAdapter +from dvsim.scheduler.async_core import Scheduler __all__ = () -# Common reasoning for expected failures to avoid duplication across tests. -# Ideally these will be removed as incorrect behaviour is fixed. -FAIL_DEP_ON_MULTIPLE_TARGETS = """ -DVSim cannot handle dependency fan-in (i.e. depending on jobs) across multiple targets. - -Specifically, when all successors of the first target are initially enqueued, they are -removed from the `scheduled` queues. If any item in another target then also depends -on those items (i.e. across *another* target), then the completion of these items will -in turn attempt to enqueue their own successors, which cannot be found as they are no -longer present in the `scheduled` queues. -""" -FAIL_DEPS_ACROSS_MULTIPLE_TARGETS = ( - "DVSim cannot handle dependency fan-out across multiple targets." -) -FAIL_DEPS_ACROSS_NON_CONSECUTIVE_TARGETS = ( - "DVSim cannot handle dependencies that span non-consecutive (non-adjacent) targets." 
-) -FAIL_IF_NO_DEPS_WITHOUT_ALL_DEPS_NEEDED = """ -Current DVSim has a strange behaviour where a job with no dependencies is dispatched if it is -marked as needing all its dependencies to pass, but fails (i.e. is killed) if it is marked as -*not* needing all of its dependencies. -""" -FAIL_DEP_OUT_OF_ORDER = """ -DVSim cannot handle jobs given in an order that define dependencies and targets such that, to -resolve the jobs according to those dependencies, the targets must be processed in a different -order to the ordering of the jobs. -""" - # Default scheduler test timeout to handle infinite loops in the scheduler DEFAULT_TIMEOUT = 0.5 @@ -170,7 +143,7 @@ def _do_launch(self) -> None: if mock.launcher_error: raise mock.launcher_error status = mock.current_status - if status == JobStatus.QUEUED: + if status in (JobStatus.SCHEDULED, JobStatus.QUEUED): return # Do not mark as running if still mocking a queued status. self.mock_context.update_running(self.job_spec) @@ -226,6 +199,20 @@ class TestMockLauncher(MockLauncher): return TestMockLauncher +# TODO: we should implement mock runtime backends now that we can give different +# job runtime backends, rather than going through the mock_ctx and mock_launcher +# interfaces. For now, to keep things simple, simply wrap the legacy mock backend +# in the adapter interface. There is value in testing this as well, but ideally we +# also want to test a native mocked runtime backend. 
+@pytest.fixture +def mock_legacy_backend(mock_launcher: type[MockLauncher]) -> LegacyLauncherAdapter: + """Legacy runtime backend for the mock launcher.""" + return LegacyLauncherAdapter(mock_launcher) + + +MOCK_BACKEND: str = "legacy" + + @dataclass class Fxt: """Collection of fixtures used for mocking and testing the scheduler.""" @@ -233,12 +220,23 @@ class Fxt: tmp_path: Path mock_ctx: MockLauncherContext mock_launcher: type[MockLauncher] + mock_legacy_backend: LegacyLauncherAdapter + + @property + def backends(self) -> dict[str, LegacyLauncherAdapter]: + """Get a backend mapping for the mocked legacy backend.""" + return {MOCK_BACKEND: self.mock_legacy_backend} @pytest.fixture -def fxt(tmp_path: Path, mock_ctx: MockLauncherContext, mock_launcher: type[MockLauncher]) -> Fxt: +def fxt( + tmp_path: Path, + mock_ctx: MockLauncherContext, + mock_launcher: type[MockLauncher], + mock_legacy_backend: LegacyLauncherAdapter, +) -> Fxt: """Fixtures used for mocking and testing the scheduler.""" - return Fxt(tmp_path, mock_ctx, mock_launcher) + return Fxt(tmp_path, mock_ctx, mock_launcher, mock_legacy_backend) def ip_meta_factory(**overrides: str | None) -> IPMeta: @@ -293,7 +291,7 @@ def make_job_paths( log = root / "log.txt" statuses = {} for status in JobStatus: - if status == JobStatus.QUEUED: + if status in (JobStatus.SCHEDULED, JobStatus.QUEUED): continue status_dir = output / status.name.lower() statuses[status] = status_dir @@ -312,6 +310,7 @@ def job_spec_factory( "name": "test_job", "job_type": "mock_type", "target": "mock_target", + "backend": None, "seed": None, "dependencies": [], "needs_all_dependencies_passing": True, @@ -426,31 +425,35 @@ class TestScheduling: """Unit tests for the scheduling decisions of the scheduler.""" @staticmethod + @pytest.mark.asyncio @pytest.mark.timeout(DEFAULT_TIMEOUT) - def test_empty(fxt: Fxt) -> None: + async def test_empty(fxt: Fxt) -> None: """Test that the scheduler can handle being given no jobs.""" - result = 
Scheduler([], fxt.mock_launcher).run() + result = await Scheduler([], fxt.backends, MOCK_BACKEND).run() assert_that(result, empty()) @staticmethod + @pytest.mark.asyncio @pytest.mark.timeout(DEFAULT_TIMEOUT) - def test_job_run(fxt: Fxt) -> None: + async def test_job_run(fxt: Fxt) -> None: """Small smoketest that the scheduler can actually run a valid job.""" job = job_spec_factory(fxt.tmp_path) - result = Scheduler([job], fxt.mock_launcher).run() + result = await Scheduler([job], fxt.backends, MOCK_BACKEND).run() _assert_result_status(result, 1) @staticmethod + @pytest.mark.asyncio @pytest.mark.timeout(DEFAULT_TIMEOUT) - def test_many_jobs_run(fxt: Fxt) -> None: + async def test_many_jobs_run(fxt: Fxt) -> None: """Smoketest that the scheduler can run multiple valid jobs.""" job_specs = make_many_jobs(fxt.tmp_path, n=5) - result = Scheduler(job_specs, fxt.mock_launcher).run() + result = await Scheduler(job_specs, fxt.backends, MOCK_BACKEND).run() _assert_result_status(result, 5) @staticmethod + @pytest.mark.asyncio @pytest.mark.timeout(DEFAULT_TIMEOUT) - def test_duplicate_jobs(fxt: Fxt) -> None: + async def test_duplicate_jobs(fxt: Fxt) -> None: """Test that the scheduler does not double-schedule jobs with duplicate names.""" workspace = build_workspace(fxt.tmp_path) job_specs = make_many_jobs(fxt.tmp_path, n=3, workspace=workspace) @@ -458,7 +461,7 @@ def test_duplicate_jobs(fxt: Fxt) -> None: for _ in range(10): job_specs.append(job_spec_factory(fxt.tmp_path, name="extra_job")) job_specs.append(job_spec_factory(fxt.tmp_path, name="extra_job_2")) - result = Scheduler(job_specs, fxt.mock_launcher).run() + result = await Scheduler(job_specs, fxt.backends, MOCK_BACKEND).run() # Current behaviour expects duplicate jobs to be *silently ignored*. # We should therefore have 3 + 3 + 2 = 8 jobs. 
_assert_result_status(result, 8) @@ -467,50 +470,61 @@ def test_duplicate_jobs(fxt: Fxt) -> None: assert_that(len(names), equal_to(len(set(names)))) @staticmethod + @pytest.mark.asyncio @pytest.mark.timeout(DEFAULT_TIMEOUT) @pytest.mark.parametrize("num_jobs", [2, 3, 5, 10, 20, 100]) - def test_parallel_dispatch(fxt: Fxt, num_jobs: int) -> None: + async def test_parallel_dispatch(fxt: Fxt, num_jobs: int) -> None: """Test that many jobs can be dispatched in parallel.""" jobs = make_many_jobs(fxt.tmp_path, num_jobs) - scheduler = Scheduler(jobs, fxt.mock_launcher) + scheduler = Scheduler(jobs, fxt.backends, MOCK_BACKEND) assert_that(fxt.mock_ctx.max_concurrent, equal_to(0)) - result = scheduler.run() + result = await scheduler.run() _assert_result_status(result, num_jobs) assert_that(fxt.mock_ctx.max_concurrent, equal_to(num_jobs)) @staticmethod + @pytest.mark.asyncio @pytest.mark.timeout(DEFAULT_TIMEOUT) @pytest.mark.parametrize("num_jobs", [5, 10, 20]) @pytest.mark.parametrize("max_parallel", [1, 5, 15, 25]) - def test_max_parallel(fxt: Fxt, num_jobs: int, max_parallel: int) -> None: - """Test that max parallel limits of launchers are used & respected.""" + @pytest.mark.parametrize("on_scheduler", [True, False]) + async def test_max_parallel( + fxt: Fxt, num_jobs: int, max_parallel: int, *, on_scheduler: bool + ) -> None: + """Test that max parallel limits of launchers & the scheduler are used & respected.""" jobs = make_many_jobs(fxt.tmp_path, num_jobs) - fxt.mock_launcher.max_parallel = max_parallel - scheduler = Scheduler(jobs, fxt.mock_launcher) + if on_scheduler: + fxt.mock_legacy_backend.max_parallelism = 0 + scheduler = Scheduler(jobs, fxt.backends, MOCK_BACKEND, max_parallelism=max_parallel) + else: + fxt.mock_legacy_backend.max_parallelism = max_parallel + scheduler = Scheduler(jobs, fxt.backends, MOCK_BACKEND) assert_that(fxt.mock_ctx.max_concurrent, equal_to(0)) - result = scheduler.run() + result = await scheduler.run() _assert_result_status(result, 
num_jobs) assert_that(fxt.mock_ctx.max_concurrent, equal_to(min(num_jobs, max_parallel))) @staticmethod + @pytest.mark.asyncio + @pytest.mark.timeout(DEFAULT_TIMEOUT) @pytest.mark.parametrize("polls", [5, 10, 50]) @pytest.mark.parametrize("final_status", [JobStatus.PASSED, JobStatus.FAILED, JobStatus.KILLED]) - @pytest.mark.timeout(DEFAULT_TIMEOUT) - def test_repeated_poll(fxt: Fxt, polls: int, final_status: JobStatus) -> None: + async def test_repeated_poll(fxt: Fxt, polls: int, final_status: JobStatus) -> None: """Test that the scheduler will repeatedly poll for a dispatched job.""" job = job_spec_factory(fxt.tmp_path) fxt.mock_ctx.set_config( job, MockJob(status_thresholds=[(0, JobStatus.RUNNING), (polls, final_status)]) ) - result = Scheduler([job], fxt.mock_launcher).run() + result = await Scheduler([job], fxt.backends, MOCK_BACKEND).run() _assert_result_status(result, 1, expected=final_status) config = fxt.mock_ctx.get_config(job) if config is not None: assert_that(config.poll_count, equal_to(polls)) @staticmethod + @pytest.mark.asyncio @pytest.mark.timeout(DEFAULT_TIMEOUT) - def test_no_over_poll(fxt: Fxt) -> None: + async def test_no_over_poll(fxt: Fxt) -> None: """Test that the schedule stops polling when it sees `PASSED`, and does not over-poll.""" jobs = make_many_jobs(fxt.tmp_path, 10) polls = [(i + 1) * 10 for i in range(10)] @@ -519,7 +533,7 @@ def test_no_over_poll(fxt: Fxt) -> None: jobs[i], MockJob(status_thresholds=[(0, JobStatus.RUNNING), (polls[i], JobStatus.PASSED)]), ) - result = Scheduler(jobs, fxt.mock_launcher).run() + result = await Scheduler(jobs, fxt.backends, MOCK_BACKEND).run() _assert_result_status(result, 10) # Check we do not unnecessarily over-poll the jobs for i in range(10): @@ -528,13 +542,8 @@ def test_no_over_poll(fxt: Fxt) -> None: assert_that(config.poll_count, equal_to(polls[i])) @staticmethod - @pytest.mark.xfail( - reason="DVSim currently errors on this case. 
When DVSim dispatches and thus launches a" - " job, it is only set to running after the launch. If a launcher error occurs, it" - " immediately invokes `_kill_item` which tries to remove it from the list of running jobs" - " (where it does not exist)." - ) - def test_launcher_error(fxt: Fxt) -> None: + @pytest.mark.asyncio + async def test_launcher_error(fxt: Fxt) -> None: """Test that the launcher correctly handles an error during job launching.""" job = job_spec_factory(fxt.tmp_path, paths=make_job_paths(fxt.tmp_path, ensure_exists=True)) fxt.mock_ctx.set_config( @@ -544,13 +553,14 @@ def test_launcher_error(fxt: Fxt) -> None: launcher_error=LauncherError("abc"), ), ) - result = Scheduler([job], fxt.mock_launcher).run() + result = await Scheduler([job], fxt.backends, MOCK_BACKEND).run() # On a launcher error, the job has failed and should be killed. _assert_result_status(result, 1, expected=JobStatus.KILLED) @staticmethod + @pytest.mark.asyncio @pytest.mark.parametrize("busy_polls", [1, 2, 5, 10]) - def test_launcher_busy_error(fxt: Fxt, busy_polls: int) -> None: + async def test_launcher_busy_error(fxt: Fxt, busy_polls: int) -> None: """Test that the launcher correctly handles the launcher busy case.""" job = job_spec_factory(fxt.tmp_path) err_mock = (busy_polls, LauncherBusyError("abc")) @@ -561,7 +571,7 @@ def test_launcher_busy_error(fxt: Fxt, busy_polls: int) -> None: launcher_busy_error=err_mock, ), ) - result = Scheduler([job], fxt.mock_launcher).run() + result = await Scheduler([job], fxt.backends, MOCK_BACKEND).run() # We expect to have successfully launched and ran, eventually. _assert_result_status(result, 1) # Check that the scheduler tried to `launch()` the correct number of times. 
@@ -577,25 +587,17 @@ class TestSchedulingStructure: """ @staticmethod + @pytest.mark.asyncio @pytest.mark.timeout(DEFAULT_TIMEOUT) - @pytest.mark.parametrize( - "needs_all_passing", - [ - True, - pytest.param( - False, - marks=pytest.mark.xfail(reason=FAIL_IF_NO_DEPS_WITHOUT_ALL_DEPS_NEEDED), - ), - ], - ) - def test_no_deps(fxt: Fxt, *, needs_all_passing: bool) -> None: + @pytest.mark.parametrize("needs_all_passing", [True, False]) + async def test_no_deps(fxt: Fxt, *, needs_all_passing: bool) -> None: """Tests scheduling of jobs without any listed dependencies.""" job = job_spec_factory(fxt.tmp_path, needs_all_dependencies_passing=needs_all_passing) - result = Scheduler([job], fxt.mock_launcher).run() + result = await Scheduler([job], fxt.backends, MOCK_BACKEND).run() _assert_result_status(result, 1) @staticmethod - def _dep_test_case( + async def _dep_test_case( fxt: Fxt, dep_list: dict[int, list[int]], passes: list[int], @@ -611,7 +613,7 @@ def _dep_test_case( ) fxt.mock_ctx.set_config(jobs[2], MockJob(default_status=JobStatus.FAILED)) fxt.mock_ctx.set_config(jobs[4], MockJob(default_status=JobStatus.FAILED)) - result = Scheduler(jobs, fxt.mock_launcher).run() + result = await Scheduler(jobs, fxt.backends, MOCK_BACKEND).run() assert_that(len(result), equal_to(5)) for job in range(5): if job in passes: @@ -623,9 +625,7 @@ def _dep_test_case( assert_that(result[job].status, equal_to(expected)) @staticmethod - @pytest.mark.xfail( - reason=FAIL_DEP_ON_MULTIPLE_TARGETS + " " + FAIL_IF_NO_DEPS_WITHOUT_ALL_DEPS_NEEDED - ) + @pytest.mark.asyncio @pytest.mark.timeout(DEFAULT_TIMEOUT) @pytest.mark.parametrize( ("dep_list", "passes"), @@ -637,16 +637,16 @@ def _dep_test_case( ({0: [1, 2, 3, 4]}, [0, 1, 3]), ], ) - def test_needs_any_dep( + async def test_needs_any_dep( fxt: Fxt, dep_list: dict[int, list[int]], passes: list[int], ) -> None: """Tests scheduling of jobs with dependencies that don't need all passing.""" - TestSchedulingStructure._dep_test_case(fxt, dep_list, 
passes, all_passing=False) + await TestSchedulingStructure._dep_test_case(fxt, dep_list, passes, all_passing=False) @staticmethod - @pytest.mark.xfail(reason=FAIL_DEP_ON_MULTIPLE_TARGETS) + @pytest.mark.asyncio @pytest.mark.timeout(DEFAULT_TIMEOUT) @pytest.mark.parametrize( ("dep_list", "passes"), @@ -659,19 +659,16 @@ def test_needs_any_dep( ({1: [0, 2, 3, 4]}, [0, 3]), ], ) - def test_needs_all_deps( + async def test_needs_all_deps( fxt: Fxt, dep_list: dict[int, list[int]], passes: list[int], ) -> None: """Tests scheduling of jobs with dependencies that need all passing.""" - TestSchedulingStructure._dep_test_case(fxt, dep_list, passes, all_passing=True) + await TestSchedulingStructure._dep_test_case(fxt, dep_list, passes, all_passing=True) @staticmethod - @pytest.mark.xfail( - reason="DVSim does not currently have logic to detect and error on" - "dependency cycles within provided job specifications." - ) + @pytest.mark.asyncio @pytest.mark.timeout(DEFAULT_TIMEOUT) @pytest.mark.parametrize( ("dep_list"), @@ -683,19 +680,18 @@ def test_needs_all_deps( {0: [1, 2, 3, 4], 1: [2, 3, 4], 2: [3, 4], 3: [4], 4: [0]}, ], ) - def test_dep_cycle(fxt: Fxt, dep_list: dict[int, list[int]]) -> None: + async def test_dep_cycle(fxt: Fxt, dep_list: dict[int, list[int]]) -> None: """Test that the scheduler can detect and handle cycles in dependencies.""" jobs = make_many_jobs(fxt.tmp_path, 5, interdeps=dep_list) # Expect that we get a ValueError when trying to make the scheduler, # due to the cycle(s) in the dependencies assert_that( - calling(Scheduler).with_args(jobs, fxt.mock_launcher), raises(ValueError, "cycle") + calling(Scheduler).with_args(jobs, fxt.backends, MOCK_BACKEND), + raises(ValueError, "cycle"), ) @staticmethod - @pytest.mark.xfail( - reason=FAIL_DEP_ON_MULTIPLE_TARGETS + " " + FAIL_DEPS_ACROSS_MULTIPLE_TARGETS - ) + @pytest.mark.asyncio @pytest.mark.timeout(DEFAULT_TIMEOUT) @pytest.mark.parametrize( ("dep_list"), @@ -706,34 +702,25 @@ def test_dep_cycle(fxt: Fxt, 
dep_list: dict[int, list[int]]) -> None: {0: [1, 2, 3, 4], 1: [2], 3: [2, 4], 4: [2]}, ], ) - def test_dep_resolution(fxt: Fxt, dep_list: dict[int, list[int]]) -> None: + async def test_dep_resolution(fxt: Fxt, dep_list: dict[int, list[int]]) -> None: """Test that the scheduler can correctly resolve complex job dependencies.""" jobs = make_many_jobs(fxt.tmp_path, 5, interdeps=dep_list) - result = Scheduler(jobs, fxt.mock_launcher).run() + result = await Scheduler(jobs, fxt.backends, MOCK_BACKEND).run() _assert_result_status(result, 5) @staticmethod + @pytest.mark.asyncio @pytest.mark.timeout(DEFAULT_TIMEOUT) - def test_deps_across_polls(fxt: Fxt) -> None: + async def test_deps_across_polls(fxt: Fxt) -> None: """Test that the scheduler can resolve multiple deps that complete at different times.""" - jobs = make_many_jobs(fxt.tmp_path, 4) - # For now, define the end job separately so that we can put it in a different target - # but keep the other jobs in the same target (to circumvent FAIL_DEP_ON_MULTIPLE_TARGETS). - jobs.append( - job_spec_factory( - fxt.tmp_path, - name="end", - dependencies=[job.name for job in jobs], - target="end_target", - ) - ) + jobs = make_many_jobs(fxt.tmp_path, 5, interdeps={4: [0, 1, 2, 3]}) polls = [i * 5 for i in range(5)] for i in range(1, 5): fxt.mock_ctx.set_config( jobs[i], MockJob(status_thresholds=[(0, JobStatus.RUNNING), (polls[i], JobStatus.PASSED)]), ) - result = Scheduler(jobs, fxt.mock_launcher).run() + result = await Scheduler(jobs, fxt.backends, MOCK_BACKEND).run() _assert_result_status(result, 5) # Sanity check that we did poll each job the correct number of times as well for i in range(1, 5): @@ -742,101 +729,93 @@ def test_deps_across_polls(fxt: Fxt) -> None: assert_that(config.poll_count, equal_to(polls[i])) @staticmethod - @pytest.mark.xfail( - reason="DVSim currently implicitly assumes that job with/in other targets" - " will be reachable (i.e. transitive) dependencies of jobs in the first target." 
- ) + @pytest.mark.asyncio @pytest.mark.timeout(DEFAULT_TIMEOUT) - def test_multiple_targets(fxt: Fxt) -> None: + async def test_multiple_targets(fxt: Fxt) -> None: """Test that the scheduler can handle jobs across many targets.""" # Create 15 jobs across 5 targets (3 jobs per target), with no dependencies. jobs = make_many_jobs(fxt.tmp_path, 15, per_job=lambda i: {"target": f"target_{i // 3}"}) - result = Scheduler(jobs, fxt.mock_launcher).run() + result = await Scheduler(jobs, fxt.backends, MOCK_BACKEND).run() _assert_result_status(result, 15) @staticmethod + @pytest.mark.asyncio @pytest.mark.timeout(DEFAULT_TIMEOUT) @pytest.mark.parametrize("num_deps", range(2, 6)) - def test_cross_target_deps(fxt: Fxt, num_deps: int) -> None: + async def test_cross_target_deps(fxt: Fxt, num_deps: int) -> None: """Test that the scheduler can handle dependencies across targets.""" deps = {i: [i - 1] for i in range(1, num_deps)} jobs = make_many_jobs(fxt.tmp_path, num_deps, interdeps=deps, vary_targets=True) - result = Scheduler(jobs, fxt.mock_launcher).run() + result = await Scheduler(jobs, fxt.backends, MOCK_BACKEND).run() _assert_result_status(result, num_deps) @staticmethod - @pytest.mark.xfail(reason=FAIL_DEP_ON_MULTIPLE_TARGETS) + @pytest.mark.asyncio @pytest.mark.timeout(DEFAULT_TIMEOUT) @pytest.mark.parametrize("num_deps", range(2, 6)) - def test_dep_fan_in(fxt: Fxt, num_deps: int) -> None: + async def test_dep_fan_in(fxt: Fxt, num_deps: int) -> None: """Test that job dependencies can fan-in from multiple other jobs.""" num_jobs = num_deps + 1 deps = {0: list(range(1, num_jobs))} jobs = make_many_jobs(fxt.tmp_path, num_jobs, interdeps=deps) - result = Scheduler(jobs, fxt.mock_launcher).run() + result = await Scheduler(jobs, fxt.backends, MOCK_BACKEND).run() _assert_result_status(result, num_jobs) @staticmethod - @pytest.mark.xfail(reason=FAIL_DEPS_ACROSS_MULTIPLE_TARGETS) + @pytest.mark.asyncio @pytest.mark.timeout(DEFAULT_TIMEOUT) @pytest.mark.parametrize("num_deps", 
range(2, 6)) - def test_dep_fan_out(fxt: Fxt, num_deps: int) -> None: + async def test_dep_fan_out(fxt: Fxt, num_deps: int) -> None: """Test that job dependencies can fan-out to multiple other jobs.""" num_jobs = num_deps + 1 deps = {i: [num_deps] for i in range(num_deps)} jobs = make_many_jobs(fxt.tmp_path, num_jobs, interdeps=deps, vary_targets=True) - result = Scheduler(jobs, fxt.mock_launcher).run() + result = await Scheduler(jobs, fxt.backends, MOCK_BACKEND).run() _assert_result_status(result, num_jobs) @staticmethod - @pytest.mark.xfail(reason=FAIL_DEPS_ACROSS_NON_CONSECUTIVE_TARGETS) + @pytest.mark.asyncio @pytest.mark.timeout(DEFAULT_TIMEOUT) - def test_non_consecutive_targets(fxt: Fxt) -> None: + async def test_non_consecutive_targets(fxt: Fxt) -> None: """Test that jobs can have non-consecutive dependencies (deps in non-adjacent targets).""" jobs = make_many_jobs(fxt.tmp_path, 4, interdeps={3: [0]}, vary_targets=True) - result = Scheduler(jobs, fxt.mock_launcher).run() + result = await Scheduler(jobs, fxt.backends, MOCK_BACKEND).run() _assert_result_status(result, 4) @staticmethod - @pytest.mark.xfail(reason=FAIL_DEP_OUT_OF_ORDER) + @pytest.mark.asyncio @pytest.mark.timeout(DEFAULT_TIMEOUT) - def test_target_out_of_order(fxt: Fxt) -> None: + async def test_target_out_of_order(fxt: Fxt) -> None: """Test that the scheduler can handle targets being given out-of-dependency-order.""" jobs = make_many_jobs(fxt.tmp_path, 4, interdeps={1: [0], 2: [3]}, vary_targets=True) # First test jobs 0 and 1 (0 -> 1). Then test jobs 2 and 3 (2 <- 3). for order in (jobs[:2], jobs[2:]): - result = Scheduler(order, fxt.mock_launcher).run() + result = await Scheduler(order, fxt.backends, MOCK_BACKEND).run() _assert_result_status(result, 2) - # TODO: it isn't clear if this is a feature that DVSim should actually support. 
- # If Job specifications can form any DAG where targets are essentially just vertex - # labels/groups, then it makes sense that we can support a target-/layer-annotated - # specification with "bi-directional" edges. If layers are structural and intended - # to be monotonically increasing, this test should be changed / removed. For now, - # we test as if the former is the intended behaviour. @staticmethod - @pytest.mark.xfail(reason="DVSim cannot currently handle this case.") + @pytest.mark.asyncio @pytest.mark.timeout(DEFAULT_TIMEOUT) - def test_bidirectional_deps(fxt: Fxt) -> None: + async def test_bidirectional_deps(fxt: Fxt) -> None: """Test that the scheduler handles bidirectional cross-target deps.""" # job_0 (target_0) -> job_1 (target_1) -> job_2 (target_0) targets = ["target_0", "target_1", "target_0"] jobs = make_many_jobs( fxt.tmp_path, 3, interdeps={0: [1], 1: [2]}, per_job=lambda i: {"target": targets[i]} ) - result = Scheduler(jobs, fxt.mock_launcher).run() + result = await Scheduler(jobs, fxt.backends, MOCK_BACKEND).run() _assert_result_status(result, 3) @staticmethod + @pytest.mark.asyncio @pytest.mark.timeout(DEFAULT_TIMEOUT) @pytest.mark.parametrize("error_status", [JobStatus.FAILED, JobStatus.KILLED]) - def test_dep_fail_propagation(fxt: Fxt, error_status: JobStatus) -> None: + async def test_dep_fail_propagation(fxt: Fxt, error_status: JobStatus) -> None: """Test that failures in job dependencies propagate.""" - # Note: job order is due to working around FAIL_DEP_OUT_OF_ORDER. 
deps = {i: [i - 1] for i in range(1, 5)} - jobs = make_many_jobs(fxt.tmp_path, n=5, interdeps=deps, vary_targets=True) + jobs = make_many_jobs(fxt.tmp_path, n=5, interdeps=deps) fxt.mock_ctx.set_config(jobs[0], MockJob(default_status=error_status)) - result = Scheduler(jobs, fxt.mock_launcher).run() + result = await Scheduler(jobs, fxt.backends, MOCK_BACKEND).run() assert_that(len(result), equal_to(5)) # The job that we configured to error should show the error status assert_that(result[0].status, equal_to(error_status)) @@ -848,12 +827,10 @@ class TestSchedulingPriority: """Unit tests for scheduler decisions related to job/target weighting/priority.""" @staticmethod - @pytest.mark.xfail( - reason=FAIL_DEPS_ACROSS_MULTIPLE_TARGETS + " " + FAIL_DEPS_ACROSS_NON_CONSECUTIVE_TARGETS - ) + @pytest.mark.asyncio @pytest.mark.timeout(DEFAULT_TIMEOUT) - def test_job_priority(fxt: Fxt) -> None: - """Test that jobs across targets are prioritised according to their weight.""" + async def test_job_priority(fxt: Fxt) -> None: + """Test that jobs across targets are prioritised according to their weight by default.""" start_job = job_spec_factory(fxt.tmp_path, name="start") weighted_jobs = make_many_jobs( fxt.tmp_path, @@ -866,28 +843,26 @@ def test_job_priority(fxt: Fxt) -> None: by_weight_dec = sorted(weighted_jobs, key=lambda job: job.weight, reverse=True) # Set max parallel = 1 so that order dispatched becomes the priority order # With max parallel > 1, jobs of many priorities are dispatched "at once". 
- fxt.mock_launcher.max_parallel = 1 - result = Scheduler(jobs, fxt.mock_launcher).run() + fxt.mock_legacy_backend.max_parallelism = 1 + result = await Scheduler(jobs, fxt.backends, MOCK_BACKEND).run() _assert_result_status(result, len(jobs)) expected_order = [start_job, *by_weight_dec] assert_that(fxt.mock_ctx.order_started, equal_to(expected_order)) @staticmethod - @pytest.mark.xfail(reason="DVSim does not handle zero weights.") + @pytest.mark.asyncio @pytest.mark.timeout(DEFAULT_TIMEOUT) - def test_zero_weight(fxt: Fxt) -> None: + async def test_zero_weight(fxt: Fxt) -> None: """Test that the scheduler can handle the case where jobs have a total weight of zero.""" jobs = make_many_jobs(fxt.tmp_path, 5, weight=0) - result = Scheduler(jobs, fxt.mock_launcher).run() - # TODO: not clear if this should evenly distribute and succeed, or error. + result = await Scheduler(jobs, fxt.backends, MOCK_BACKEND).run() + # Zero weight should just mark a job as the lowest priority, but the jobs should still run. _assert_result_status(result, 5) @staticmethod - @pytest.mark.xfail( - reason=FAIL_DEPS_ACROSS_MULTIPLE_TARGETS + " " + FAIL_DEPS_ACROSS_NON_CONSECUTIVE_TARGETS - ) + @pytest.mark.asyncio @pytest.mark.timeout(DEFAULT_TIMEOUT) - def test_blocked_weight_starvation(fxt: Fxt) -> None: + async def test_blocked_weight_starvation(fxt: Fxt) -> None: """Test that high weight jobs without fulfilled deps do not block lower weight jobs.""" # All jobs spawn from a start job. 
# There is one chain "start -> long_blocker -> high" where we have a high weight job @@ -915,30 +890,39 @@ def test_blocked_weight_starvation(fxt: Fxt) -> None: long_blocker, MockJob(status_thresholds=[(0, JobStatus.RUNNING), (5, JobStatus.PASSED)]), ) - result = Scheduler(jobs, fxt.mock_launcher).run() - _assert_result_status(result, 8) + # Do not coalesce nearby events, as otherwise the blockers may complete close + # enough with a low/zero polling frequency that they get batched and the + # high priority job is scheduled first. + result = await Scheduler(jobs, fxt.backends, MOCK_BACKEND, coalesce_window=None).run() + _assert_result_status(result, len(jobs)) # We expect that the high weight job should have been scheduled last, since # it was blocked by the blocker (unlike all the other lower weight jobs) assert_that(fxt.mock_ctx.order_started[0], equal_to(start_job)) assert_that(fxt.mock_ctx.order_started[-1], equal_to(high)) - # TODO: we do not currently test the logic to schedule multiple queued jobs per target - # across different targets based on the weights of those jobs/targets, because this - # will require the test to be quite complex and specific to the intricacies of the - # current DVSim scheduler due to the current implementation. Due to only one successor - # in another target being discovered at once, we must carefully construct a dependency - # tree of jobs with specially modelled delays which relies on this implementation - # detail. Instead, for now at least, we leave this untested. - # - # Note also that DVSim currently assumes weights within a target are constant, - # which may not be the case with the current JobSpec model. 
+ @staticmethod + @pytest.mark.asyncio + @pytest.mark.timeout(DEFAULT_TIMEOUT) + async def test_custom_priority(fxt: Fxt) -> None: + """Test that a custom prioritization function can be given to and used by the scheduler.""" + jobs = make_many_jobs( + fxt.tmp_path, n=5, per_job=lambda n: {"name": str(n), "weight": n + 1} + ) + # Prioritizes jobs via their names (lower names have higher priority, so come first). + # So jobs should be scheduled in the order created, instead of the opposite default order + # by decreasing weight. + result = await Scheduler( + jobs, fxt.backends, MOCK_BACKEND, priority_fn=lambda job: -int(job.spec.name) + ).run() + _assert_result_status(result, len(jobs)) + assert_that(fxt.mock_ctx.order_started, equal_to(jobs)) class TestSignals: """Integration tests for the signal-handling of the scheduler.""" @staticmethod - def _run_signal_test(tmp_path: Path, sig: int, *, repeat: bool, long_poll: bool) -> None: + async def _run_signal_test(tmp_path: Path, sig: int, *, repeat: bool, long_poll: bool) -> None: """Test that the scheduler can be gracefully killed by incoming signals.""" # We cannot access the fixtures from the separate process, so define a minimal @@ -954,6 +938,9 @@ class SignalTestMockLauncher(MockLauncher): # scheduler from a sleep if configured with infrequent polls. SignalTestMockLauncher.poll_freq = 360000 + # TODO: use a mocked runtime backend instead of a wrapper around the launcher + backend = LegacyLauncherAdapter(SignalTestMockLauncher) + jobs = make_many_jobs(tmp_path, 3, ensure_paths_exist=True) # When testing non-graceful exits, we make `kill()` hang and send two signals. kill_time = None if not repeat else 100.0 @@ -970,7 +957,7 @@ class SignalTestMockLauncher(MockLauncher): # Job 2 is also permanently "running", but will never run due to the # max paralellism limit on the launcher. It will instead be cancelled. 
mock_ctx.set_config(jobs[2], MockJob(default_status=JobStatus.RUNNING, kill_time=kill_time)) - scheduler = Scheduler(jobs, SignalTestMockLauncher) + scheduler = Scheduler(jobs, {MOCK_BACKEND: backend}, MOCK_BACKEND) def _get_signal(sig_received: int, _: FrameType | None) -> None: assert_that(sig_received, equal_to(sig)) @@ -996,7 +983,7 @@ def _send_signals() -> None: # Send signals from a separate thread threading.Thread(target=_send_signals).start() - result = scheduler.run() + result = await scheduler.run() # If we didn't reach `_get_signal`, this should be a graceful exit assert_that(not repeat) diff --git a/uv.lock b/uv.lock index 68a9fc53..9fc9de9a 100644 --- a/uv.lock +++ b/uv.lock @@ -46,6 +46,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d2/39/e7eaf1799466a4aef85b6a4fe7bd175ad2b1c6345066aa33f1f58d4b18d0/asttokens-3.0.1-py3-none-any.whl", hash = "sha256:15a3ebc0f43c2d0a50eeafea25e19046c68398e487b9f1f5b517f7c0f40f976a", size = 27047, upload-time = "2025-11-15T16:43:16.109Z" }, ] +[[package]] +name = "backports-asyncio-runner" +version = "1.2.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/8e/ff/70dca7d7cb1cbc0edb2c6cc0c38b65cba36cccc491eca64cabd5fe7f8670/backports_asyncio_runner-1.2.0.tar.gz", hash = "sha256:a5aa7b2b7d8f8bfcaa2b57313f70792df84e32a2a746f585213373f900b42162", size = 69893, upload-time = "2025-07-02T02:27:15.685Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a0/59/76ab57e3fe74484f48a53f8e337171b4a2349e506eabe136d7e01d059086/backports_asyncio_runner-1.2.0-py3-none-any.whl", hash = "sha256:0da0a936a8aeb554eccb426dc55af3ba63bcdc69fa1a600b5bb305413a4477b5", size = 12313, upload-time = "2025-07-02T02:27:14.263Z" }, +] + [[package]] name = "blessed" version = "1.33.0" @@ -378,6 +387,7 @@ ci = [ { name = "pyhamcrest" }, { name = "pyright" }, { name = "pytest" }, + { name = "pytest-asyncio" }, { name = "pytest-cov" }, { name = "pytest-repeat" }, { 
name = "pytest-timeout" }, @@ -396,6 +406,7 @@ dev = [ { name = "pyhamcrest" }, { name = "pyright" }, { name = "pytest" }, + { name = "pytest-asyncio" }, { name = "pytest-cov" }, { name = "pytest-repeat" }, { name = "pytest-timeout" }, @@ -413,6 +424,7 @@ nix = [ { name = "pyhamcrest" }, { name = "pyright" }, { name = "pytest" }, + { name = "pytest-asyncio" }, { name = "pytest-cov" }, { name = "pytest-repeat" }, { name = "pytest-timeout" }, @@ -424,6 +436,7 @@ release = [ test = [ { name = "pyhamcrest" }, { name = "pytest" }, + { name = "pytest-asyncio" }, { name = "pytest-cov" }, { name = "pytest-repeat" }, { name = "pytest-timeout" }, @@ -452,6 +465,7 @@ requires-dist = [ { name = "pyhamcrest", marker = "extra == 'test'", specifier = ">=2.1.0" }, { name = "pyright", marker = "extra == 'typing'", specifier = ">=1.1.381" }, { name = "pytest", marker = "extra == 'test'", specifier = ">=8.3.3" }, + { name = "pytest-asyncio", marker = "extra == 'test'", specifier = ">=1.3.0" }, { name = "pytest-cov", marker = "extra == 'test'", specifier = ">=5.0.0" }, { name = "pytest-repeat", marker = "extra == 'test'", specifier = ">=0.9.4" }, { name = "pytest-timeout", marker = "extra == 'test'", specifier = ">=2.4.0" }, @@ -1104,6 +1118,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3b/ab/b3226f0bd7cdcf710fbede2b3548584366da3b19b5021e74f5bde2a8fa3f/pytest-9.0.2-py3-none-any.whl", hash = "sha256:711ffd45bf766d5264d487b917733b453d917afd2b0ad65223959f59089f875b", size = 374801, upload-time = "2025-12-06T21:30:49.154Z" }, ] +[[package]] +name = "pytest-asyncio" +version = "1.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "backports-asyncio-runner", marker = "python_full_version < '3.11'" }, + { name = "pytest" }, + { name = "typing-extensions", marker = "python_full_version < '3.13'" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/90/2c/8af215c0f776415f3590cac4f9086ccefd6fd463befeae41cd4d3f193e5a/pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5", size = 50087, upload-time = "2025-11-10T16:07:47.256Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e5/35/f8b19922b6a25bc0880171a2f1a003eaeb93657475193ab516fd87cac9da/pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5", size = 15075, upload-time = "2025-11-10T16:07:45.537Z" }, +] + [[package]] name = "pytest-cov" version = "7.0.0"