diff --git a/src/dvsim/scheduler.py b/src/dvsim/scheduler.py index 515c860..766778a 100644 --- a/src/dvsim/scheduler.py +++ b/src/dvsim/scheduler.py @@ -5,8 +5,10 @@ """Job scheduler.""" import contextlib -import threading +import os +import selectors from collections.abc import ( + Callable, Mapping, MutableMapping, MutableSequence, @@ -164,41 +166,57 @@ def __init__( # variant-specific settings such as max parallel jobs & poll rate. self._launcher_cls: type[Launcher] = launcher_cls - def run(self) -> Sequence[CompletedJobStatus]: - """Run all scheduled jobs and return the results. + def _handle_exit_signal(self, last_received_signal: int, handler: Callable) -> None: + """Handle a received exit (SIGINT/SIGTERM signal) in the main scheduler loop. - Returns the results (status) of all items dispatched for all - targets and cfgs. + On either signal, this will tell runners to quit and cancel future jobs. + On receiving a SIGINT specifically, this re-installs the old signal handler + such that subsequent SIGINT signals will kill the process (non-gracefully). """ - timer = Timer() - - # Catch one SIGINT and tell the runner to quit. On a second, die. - stop_now = threading.Event() - old_handler = None + log.info( + "Received signal %s. Exiting gracefully.", + last_received_signal, + ) - def on_signal(signal_received: int, _: FrameType | None) -> None: + if last_received_signal == SIGINT: log.info( - "Received signal %s. Exiting gracefully.", - signal_received, + "Send another to force immediate quit (but you may " + "need to manually kill child processes)", ) - if signal_received == SIGINT: - log.info( - "Send another to force immediate quit (but you may " - "need to manually kill child processes)", - ) + # Restore old handler to catch a second SIGINT + signal(SIGINT, handler) - # Restore old handler to catch a second SIGINT - if old_handler is None: - raise RuntimeError("Old SIGINT handler not found") + self._kill() - signal(signal_received, old_handler) + def run(self) -> Sequence[CompletedJobStatus]: + """Run all scheduled jobs and return the results. - stop_now.set() + Returns the results (status) of all items dispatched for all + targets and cfgs. + """ + timer = Timer() - old_handler = signal(SIGINT, on_signal) + # On SIGTERM or SIGINT, tell the runner to quit. + # On a second SIGINT specifically, die. + sel = selectors.DefaultSelector() + signal_rfd, signal_wfd = os.pipe() + sel.register(signal_rfd, selectors.EVENT_READ) + last_received_signal: int | None = None - # Install the SIGTERM handler before scheduling jobs. + def on_signal(signal_received: int, _: FrameType | None) -> None: + # To allow async-safe-signal logic where signals can be handled + # while sleeping, we use a selector to perform a blocking wait, + # and signal the event through a pipe. We then set a flag with + # the received signal. Like this, we can receive a signal + # at any point, and it can also interrupt the poll wait to take + # immediate effect. + nonlocal last_received_signal + last_received_signal = signal_received + os.write(signal_wfd, b"\x00") + + # Install the SIGINT and SIGTERM handlers before scheduling jobs. + old_handler = signal(SIGINT, on_signal) signal(SIGTERM, on_signal) # Enqueue all items of the first target. @@ -206,9 +224,9 @@ def on_signal(signal_received: int, _: FrameType | None) -> None: try: while True: - if stop_now.is_set(): - # We've had an interrupt. Kill any jobs that are running. - self._kill() + if last_received_signal is not None: + self._handle_exit_signal(last_received_signal, old_handler) + last_received_signal = None hms = timer.hms() changed = self._poll(hms) or timer.check_time() @@ -216,11 +234,8 @@ def on_signal(signal_received: int, _: FrameType | None) -> None: if changed and self._check_if_done(hms): break - # This is essentially sleep(1) to wait a second between each - # polling loop. But we do it with a bounded wait on stop_now so - # that we jump back to the polling loop immediately on a - # signal. - stop_now.wait(timeout=self._launcher_cls.poll_freq) + # Wait between each poll, except we may be woken by a signal. + sel.select(timeout=self._launcher_cls.poll_freq) finally: signal(SIGINT, old_handler) diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index df7a961..91667e8 100644 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -1011,12 +1011,6 @@ def _send_signals() -> None: _assert_result_status(result, 3, expected=JobStatus.KILLED) @staticmethod - @pytest.mark.xfail( - reason="This test passes ~95 percent of the time, but the logging & threading primitive" - "logic used in the signal handler are not async-signal-safe and thus may deadlock," - "causing the process to hang and time out instead.", - strict=False, - ) @pytest.mark.parametrize("long_poll", [False, True]) @pytest.mark.parametrize(("sig", "repeat"), [(SIGTERM, False), (SIGINT, False), (SIGINT, True)]) def test_signal_kill(tmp_path: Path, *, sig: int, repeat: bool, long_poll: bool) -> None: