Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 48 additions & 33 deletions src/dvsim/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
"""Job scheduler."""

import contextlib
import threading
import os
import selectors
from collections.abc import (
Callable,
Mapping,
MutableMapping,
MutableSequence,
Expand Down Expand Up @@ -164,63 +166,76 @@ def __init__(
# variant-specific settings such as max parallel jobs & poll rate.
self._launcher_cls: type[Launcher] = launcher_cls

def run(self) -> Sequence[CompletedJobStatus]:
"""Run all scheduled jobs and return the results.
def _handle_exit_signal(self, last_received_signal: int, handler: Callable) -> None:
"""Handle a received exit (SIGINT/SIGTERM signal) in the main scheduler loop.

Returns the results (status) of all items dispatched for all
targets and cfgs.
On either signal, this will tell runners to quit and cancel future jobs.
On receiving a SIGINT specifically, this re-installs the old signal handler
such that subsequent SIGINT signals will kill the process (non-gracefully).
"""
timer = Timer()

# Catch one SIGINT and tell the runner to quit. On a second, die.
stop_now = threading.Event()
old_handler = None
log.info(
"Received signal %s. Exiting gracefully.",
last_received_signal,
)

def on_signal(signal_received: int, _: FrameType | None) -> None:
if last_received_signal == SIGINT:
log.info(
"Received signal %s. Exiting gracefully.",
signal_received,
"Send another to force immediate quit (but you may "
"need to manually kill child processes)",
)

if signal_received == SIGINT:
log.info(
"Send another to force immediate quit (but you may "
"need to manually kill child processes)",
)
# Restore old handler to catch a second SIGINT
signal(SIGINT, handler)

# Restore old handler to catch a second SIGINT
if old_handler is None:
raise RuntimeError("Old SIGINT handler not found")
self._kill()

signal(signal_received, old_handler)
def run(self) -> Sequence[CompletedJobStatus]:
"""Run all scheduled jobs and return the results.

stop_now.set()
Returns the results (status) of all items dispatched for all
targets and cfgs.
"""
timer = Timer()

old_handler = signal(SIGINT, on_signal)
# On SIGTERM or SIGINT, tell the runner to quit.
# On a second SIGINT specifically, die.
sel = selectors.DefaultSelector()
signal_rfd, signal_wfd = os.pipe()
sel.register(signal_rfd, selectors.EVENT_READ)
last_received_signal: int | None = None

# Install the SIGTERM handler before scheduling jobs.
def on_signal(signal_received: int, _: FrameType | None) -> None:
# To keep this handler async-signal-safe while still letting a signal
# interrupt the sleep between polls, the main loop blocks in a selector
# and the handler only records the received signal in a flag and writes
# a byte to a pipe registered with that selector. This way a signal can
# arrive at any point, and it also wakes the poll wait so that it takes
# immediate effect.
nonlocal last_received_signal
last_received_signal = signal_received
os.write(signal_wfd, b"\x00")

# Install the SIGINT and SIGTERM handlers before scheduling jobs.
old_handler = signal(SIGINT, on_signal)
signal(SIGTERM, on_signal)

# Enqueue all items of the first target.
self._enqueue_successors(None)

try:
while True:
if stop_now.is_set():
# We've had an interrupt. Kill any jobs that are running.
self._kill()
if last_received_signal is not None:
self._handle_exit_signal(last_received_signal, old_handler)
last_received_signal = None

hms = timer.hms()
changed = self._poll(hms) or timer.check_time()
self._dispatch(hms)
if changed and self._check_if_done(hms):
break

# This is essentially sleep(1) to wait a second between each
# polling loop. But we do it with a bounded wait on stop_now so
# that we jump back to the polling loop immediately on a
# signal.
stop_now.wait(timeout=self._launcher_cls.poll_freq)
# Wait between polls; a signal received in the meantime wakes us early.
sel.select(timeout=self._launcher_cls.poll_freq)

finally:
signal(SIGINT, old_handler)
Expand Down
6 changes: 0 additions & 6 deletions tests/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1011,12 +1011,6 @@ def _send_signals() -> None:
_assert_result_status(result, 3, expected=JobStatus.KILLED)

@staticmethod
@pytest.mark.xfail(
reason="This test passes ~95 percent of the time, but the logging & threading primitive"
"logic used in the signal handler are not async-signal-safe and thus may deadlock,"
"causing the process to hang and time out instead.",
strict=False,
)
@pytest.mark.parametrize("long_poll", [False, True])
@pytest.mark.parametrize(("sig", "repeat"), [(SIGTERM, False), (SIGINT, False), (SIGINT, True)])
def test_signal_kill(tmp_path: Path, *, sig: int, repeat: bool, long_poll: bool) -> None:
Expand Down
Loading