diff --git a/docs/contributor_guide/index.md b/docs/contributor_guide/index.md index 06292a05ccd..4e668b6f0eb 100644 --- a/docs/contributor_guide/index.md +++ b/docs/contributor_guide/index.md @@ -7,6 +7,7 @@ contributing how-to-open-source contributing-to-docs +structured-concurrency ```
diff --git a/docs/contributor_guide/structured-concurrency.md b/docs/contributor_guide/structured-concurrency.md new file mode 100644 index 00000000000..8851ad621f4 --- /dev/null +++ b/docs/contributor_guide/structured-concurrency.md @@ -0,0 +1,84 @@ +# Structured Concurrency in PyLabRobot + +## API + +In PyLabRobot, all asynchronous resources expose the `pylabrobot.concurrency.AsyncResource` API: resources are usable exactly within the body of `async with resource:`. +What exactly *usable* means may depend on the resource, though, +as some functionality *may* be available outside the `async with` block too. +Unless that is specified by the API for a specific resource, you should not rely on it. + +### Implementing `AsyncResource` + +When implementing `AsyncResource` for a new class, you should not write `__aenter__` and `__aexit__` directly, as this is difficult to get right. +Instead, you should implement the `_lifespan` async context manager. +It is often most convenient to do so in terms of a `contextlib.AsyncExitStack`, +so the default implementation of `_lifespan` does that and delegates to a `_enter_lifespan(stack)` coroutine. +There is no `_exit_lifespan` (separate enter and exit calls are the antithesis of structured concurrency); +instead, all cleanup is registered with the `stack`. + +### Legacy `setup`/`stop` calls + +For historical reasons and to support certain interactive use-cases, +we still expose a `setup`/`stop` API in subclasses of `Machine`. +Note, however, that with this API you give away control over the scope of the async work: for example, there is no way to reliably catch all errors in background tasks, or to handle cancellation of tasks consistently. Do not use this API in production scripts. + +## Testing + +Previous testing within PLR relied on `unittest.IsolatedAsyncioTestCase`. +Unfortunately, the `unittest` paradigm is fundamentally incompatible with structured concurrency.
+There is no structured scope enclosing the tests, and all attempts to work around this failed. + +Instead, we provide `pylabrobot.testing.concurrency.AnyioTestBase`. +This is deliberately *not* a `unittest.TestCase`, in order not to trigger `pytest`'s +`unittest` compatibility mode. It *does*, however, reimplement the asserts from `unittest`, +so as to streamline test conversion. +Test cases can be left as-is, but the `setUp`/`asyncSetUp` / `tearDown`/`asyncTearDown` logic needs to be replaced by a `_lifespan` or `_enter_lifespan` implementation (the test base is an `AsyncResource` itself). + +### Gotchas: +- `unittest.mock.AsyncMock` creates `async` methods that never yield. + This is a problem if they are used in a tight loop with no other yield point, + leading to a deadlock. This appears in the wild in reader loops of I/O plumbing, + so we provide `pylabrobot.testing.mock_io.MockIO` as a more focused alternative. + +## Notes from the refactor: +- Timeout semantics may have changed slightly. Usually, that is because the previous + timeout semantics were confusing or ill-specified (without structured concurrency, + it is very hard to implement good timeout semantics). We tried to stay as close as possible to the previous semantics. That said, going forward, a `timeout` argument should always be a trigger to take a step back and think about semantics: Is it supposed to be a timeout on the full operation? Then *don't* add a `timeout` argument at all! Users are better served by wrapping + *the whole operation* with `with anyio.fail_after`. If the timeout applies only to sub-parts of the operation, + then be very careful to specify exactly what it applies to (and what happens when it expires). + +## Limitations: + - The Opentrons thermocycler USB backend is `asyncio`-only. + +## Issues found during the refactor + +### Unstructured start/stop behaviours that might be better off as context managers +- `shake` and `stop_shaking` on Agilent Biotek.
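As a rough sketch of what the structured alternative to such a start/stop pair could look like: `DemoShaker` below and its method names are hypothetical stand-ins (not the real Agilent Biotek backend); the point is that the stop call runs on every exit path, whether the block ends normally, raises, or is cancelled.

```python
import asyncio
import contextlib


class DemoShaker:
  """Hypothetical stand-in for a shaker backend with unstructured shake/stop calls."""

  def __init__(self):
    self.log = []

  async def shake(self, speed: int):
    self.log.append(f"shake@{speed}")

  async def stop_shaking(self):
    self.log.append("stop")

  @contextlib.asynccontextmanager
  async def shaking(self, speed: int):
    # Structured wrapper: the shaker is guaranteed to be stopped when the
    # block exits, whether normally, by exception, or by cancellation.
    await self.shake(speed)
    try:
      yield self
    finally:
      await self.stop_shaking()


async def main():
  shaker = DemoShaker()
  async with shaker.shaking(300):
    pass  # incubate, read the plate, ...
  return shaker.log


print(asyncio.run(main()))  # ['shake@300', 'stop']
```

A real backend might equally register the stop call on an `AsyncExitStack` instead of an ad-hoc `try`/`finally`, which is the same pattern `_enter_lifespan` uses.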
+ +### Inconsistent "turn-off" behaviour of various machines +Most machines seem to turn off any ongoing actions and go back to some form of "parking position", but other machines don't: +- Tecan EVO has a number of arms that one could park; currently, we don't. + + +## TODOs in the refactor + +### References to `setup` + - Developer docs + - Many error messages + - `.setup_done()` calls + +### References to `unittest` + - Async tests now *require* pytest - let's remove all calls to `unittest.main()` + +### Check for other signs that are frowned upon with structured concurrency: + - Anything involving `time.time()` or `time.monotonic()` - should at least be `anyio.current_time()`, but is often a sign of a busy-loop or manual timeout handling. + - Check for use of `threading`. + - Check for use of `asyncio` - raw `asyncio` APIs should all be converted to `anyio` or something else that is loop-agnostic. + +### Verification checks for changes already made + - `_enter_lifespan` extra arguments other than `stack` should be *keyword-only*! + - Have a look at all `stack.push_async_callback` registrations, especially for `cleanup()` functions - these could often in fact be sync. + - Verify that all cleanup logic has cancellation-shielding in place where necessary. + +### Things to watch out for +- Never catch a cancellation without re-raising it. In plain `asyncio`, that might be ok, but in structured concurrency, it never is.
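The last bullet can be illustrated with a minimal, self-contained `asyncio` sketch (PLR itself uses `anyio`, but the invariant is the same): a task may run cleanup when it is cancelled, but it must re-raise the cancellation so the enclosing scope can unwind.

```python
import asyncio


async def worker(log):
  try:
    await asyncio.sleep(10)
  except asyncio.CancelledError:
    # Cleanup on cancellation is fine, but the cancellation must propagate;
    # swallowing it would leave the enclosing scope unable to unwind.
    log.append("cleaned up")
    raise  # never catch a cancellation without re-raising


async def main():
  log = []
  task = asyncio.create_task(worker(log))
  await asyncio.sleep(0)  # let the worker start and suspend
  task.cancel()
  try:
    await task
  except asyncio.CancelledError:
    log.append("cancellation observed by caller")
  return log


print(asyncio.run(main()))  # ['cleaned up', 'cancellation observed by caller']
```

Under `anyio`, the same rule applies to `anyio.get_cancelled_exc_class()`, and any `await` performed during such cleanup additionally needs a shielded cancel scope, which is what the cancellation-shielding checks above are about.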
\ No newline at end of file diff --git a/pylabrobot/arms/precise_flex/precise_flex_backend.py b/pylabrobot/arms/precise_flex/precise_flex_backend.py index b7acab98566..032e92ad6e3 100644 --- a/pylabrobot/arms/precise_flex/precise_flex_backend.py +++ b/pylabrobot/arms/precise_flex/precise_flex_backend.py @@ -1,8 +1,9 @@ -import asyncio import warnings from abc import ABC from typing import Dict, List, Literal, Optional, Union +import anyio + from pylabrobot.arms.backend import ( AccessPattern, HorizontalAccess, @@ -12,6 +13,7 @@ from pylabrobot.arms.precise_flex.coords import ElbowOrientation, PreciseFlexCartesianCoords from pylabrobot.arms.precise_flex.error_codes import ERROR_CODES from pylabrobot.arms.precise_flex.joints import PFAxis +from pylabrobot.concurrency import AsyncExitStackWithShielding from pylabrobot.io.socket import Socket from pylabrobot.resources import Coordinate, Rotation @@ -54,6 +56,7 @@ def __init__( self.timeout = timeout self._has_rail = has_rail self._is_dual_gripper = is_dual_gripper + if is_dual_gripper: warnings.warn( "Dual gripper support is experimental and may not work as expected.", UserWarning @@ -90,21 +93,21 @@ def _convert_to_cartesian_array( ) return arr - async def setup(self, skip_home: bool = False): - """Initialize the PreciseFlex backend.""" - await self.io.setup() + async def _enter_lifespan(self, stack: AsyncExitStackWithShielding, *, skip_home: bool = False): + await super()._enter_lifespan(stack) + + await stack.enter_async_context(self.io) + stack.push_shielded_async_callback(self.exit) + await self.set_response_mode("pc") await self.power_on_robot() await self.attach(1) if not skip_home: await self.home() - async def stop(self): - """Stop the PreciseFlex backend.""" - await self.detach() - await self.power_off_robot() - await self.exit() - await self.io.stop() + # push_async_callback executes in reverse order! 
+ stack.push_shielded_async_callback(self.power_off_robot) + stack.push_shielded_async_callback(self.detach) async def set_speed(self, speed_percent: float): """Set the speed percentage of the arm's movement (0-100).""" @@ -1591,7 +1594,7 @@ async def wait_for_eom(self) -> None: some other means. Does not reply until the robot has stopped. """ await self.send_command("waitForEom") - await asyncio.sleep(0.2) # Small delay to ensure command is fully processed + await anyio.sleep(0.2) # Small delay to ensure command is fully processed async def zero_torque(self, enable: bool, axis_mask: int = 1) -> None: """Sets or clears zero torque mode for the selected robot. diff --git a/pylabrobot/arms/precise_flex/precise_flex_backend_tests.py b/pylabrobot/arms/precise_flex/precise_flex_backend_tests.py index 0742f7281e7..acbefed61c6 100644 --- a/pylabrobot/arms/precise_flex/precise_flex_backend_tests.py +++ b/pylabrobot/arms/precise_flex/precise_flex_backend_tests.py @@ -1,28 +1,31 @@ -import unittest from typing import Dict from unittest.mock import AsyncMock, patch +import anyio + from pylabrobot.arms.backend import HorizontalAccess, VerticalAccess from pylabrobot.arms.precise_flex.coords import ElbowOrientation, PreciseFlexCartesianCoords from pylabrobot.arms.precise_flex.joints import PFAxis from pylabrobot.arms.precise_flex.precise_flex_backend import PreciseFlexBackend, PreciseFlexError from pylabrobot.io.socket import Socket # Import Socket for mocking from pylabrobot.resources import Coordinate, Rotation +from pylabrobot.testing.concurrency import AnyioTestBase -class PreciseFlexBackendHardwareTests(unittest.IsolatedAsyncioTestCase): +class TestPreciseFlexBackendHardware(AnyioTestBase): """Integration tests for PreciseFlex robot - RUNS ON ACTUAL HARDWARE""" -class PreciseFlexBackendTests(unittest.IsolatedAsyncioTestCase): +class TestPreciseFlexBackend(AnyioTestBase): """Unit tests for PreciseFlexBackend""" - def setUp(self): + async def _enter_lifespan(self, stack): + 
await super()._enter_lifespan(stack) + self.mock_socket_instance = AsyncMock(spec=Socket) self.mock_socket_instance.read.return_value = b"" self.mock_socket_instance.readline.return_value = b"" self.mock_socket_instance.write.return_value = None - self.mock_socket_instance.setup.return_value = None # Configure setup to return None self.mock_socket_instance._writer = AsyncMock() # Mock the _writer attribute # Patch the Socket class where it's used in PreciseFlexBackend @@ -31,9 +34,10 @@ def setUp(self): return_value=self.mock_socket_instance, ) self.MockSocketClass = patcher.start() # Store the mock of the class - self.addCleanup(patcher.stop) + stack.push_async_callback(lambda: anyio.to_thread.run_sync(patcher.stop)) self.backend = PreciseFlexBackend(has_rail=False, host="localhost", port=10100) + # self.backend.io is already self.mock_socket_instance because of the patch async def test_init(self): @@ -92,9 +96,14 @@ async def test_setup(self): b"0 OK\r\n", # power_on_robot b"0 OK\r\n", # attach b"0 OK\r\n", # home + b"0 attach\r\n", # detach + b"0 hp\r\n", # power_off_robot + b"0 exit\r\n", # exit ] - await self.backend.setup() - self.mock_socket_instance.setup.assert_called_once() + async with self.backend: + pass + self.mock_socket_instance.__aenter__.assert_called_once() + self.mock_socket_instance.write.assert_any_call(b"mode 0\n") self.mock_socket_instance.write.assert_any_call(b"hp 1 20\n") self.mock_socket_instance.write.assert_any_call(b"attach 1\n") @@ -102,15 +111,20 @@ async def test_setup(self): async def test_stop(self): self.mock_socket_instance.readline.side_effect = [ + b"0 OK\r\n", # set_mode + b"0 OK\r\n", # power_on_robot + b"0 OK\r\n", # attach + b"0 OK\r\n", # home b"0 attach\r\n", # detach b"0 hp\r\n", # power_off_robot b"0 exit\r\n", # exit ] - await self.backend.stop() + async with self.backend: + pass self.mock_socket_instance.write.assert_any_call(b"attach 0\n") self.mock_socket_instance.write.assert_any_call(b"hp 0\n") 
self.mock_socket_instance.write.assert_any_call(b"exit\n") - self.mock_socket_instance.stop.assert_called_once() + self.mock_socket_instance.__aexit__.assert_called_once() async def test_set_speed(self): self.mock_socket_instance.readline.return_value = b"0 Speed 1 50.0\r\n" diff --git a/pylabrobot/arms/scara_tests.py b/pylabrobot/arms/scara_tests.py index 4455abc104d..5360a4b0b1e 100644 --- a/pylabrobot/arms/scara_tests.py +++ b/pylabrobot/arms/scara_tests.py @@ -1,14 +1,14 @@ -import unittest from unittest.mock import AsyncMock, MagicMock from pylabrobot.arms.backend import SCARABackend from pylabrobot.arms.precise_flex.coords import PreciseFlexCartesianCoords from pylabrobot.arms.scara import ExperimentalSCARA from pylabrobot.resources import Coordinate, Rotation +from pylabrobot.testing.concurrency import AnyioTestBase -class TestExperimentalSCARA(unittest.IsolatedAsyncioTestCase): - async def asyncSetUp(self): +class TestExperimentalSCARA(AnyioTestBase): + async def _enter_lifespan(self, stack): self.mock_backend = MagicMock(spec=SCARABackend) for method_name in [ "move_to", diff --git a/pylabrobot/barcode_scanners/keyence/keyence_backend.py b/pylabrobot/barcode_scanners/keyence/keyence_backend.py index e79501fda71..3fee0ae1258 100644 --- a/pylabrobot/barcode_scanners/keyence/keyence_backend.py +++ b/pylabrobot/barcode_scanners/keyence/keyence_backend.py @@ -1,6 +1,6 @@ -import asyncio import logging -import time + +import anyio try: import serial @@ -14,6 +14,7 @@ BarcodeScannerBackend, BarcodeScannerError, ) +from pylabrobot.concurrency import AsyncExitStackWithShielding from pylabrobot.io.serial import Serial from pylabrobot.resources.barcode import Barcode @@ -51,26 +52,27 @@ def __init__( rtscts=False, ) - async def setup(self): - await self.io.setup() + async def _enter_lifespan(self, stack: AsyncExitStackWithShielding): + await super()._enter_lifespan(stack) + await stack.enter_async_context(self.io) await self.initialize() async def initialize(self): 
"""Initialize the Keyence barcode scanner.""" - - deadline = time.time() + self.init_timeout - while time.time() < deadline: - response = await self.send_command("RMOTOR") - if response.strip() == "MOTORON": - logger.info("Barcode scanner motor is ON.") - break - elif response.strip() == "MOTOROFF": - raise BarcodeScannerError("Failed to initialize Keyence barcode scanner: Motor is off.") - await asyncio.sleep(self.poll_interval) - else: + try: + with anyio.fail_after(self.init_timeout): + while True: + response = await self.send_command("RMOTOR") + if response.strip() == "MOTORON": + logger.info("Barcode scanner motor is ON.") + break + elif response.strip() == "MOTOROFF": + raise BarcodeScannerError("Failed to initialize Keyence barcode scanner: Motor is off.") + await anyio.sleep(self.poll_interval) + except TimeoutError as e: raise BarcodeScannerError( "Failed to initialize Keyence barcode scanner: Timeout waiting for motor to turn on." - ) + ) from e async def send_command(self, command: str) -> str: """Send a command to the barcode scanner and return the response. 
@@ -80,9 +82,6 @@ async def send_command(self, command: str) -> str: response = await self.io.read() return response.decode(self.serial_messaging_encoding).strip() - async def stop(self): - await self.io.stop() - async def scan_barcode(self) -> Barcode: data = await self.send_command("LON") if data.startswith("NG"): diff --git a/pylabrobot/centrifuge/centrifuge_tests.py b/pylabrobot/centrifuge/centrifuge_tests.py index 9dbf6c56d8b..26c8d0d6421 100644 --- a/pylabrobot/centrifuge/centrifuge_tests.py +++ b/pylabrobot/centrifuge/centrifuge_tests.py @@ -12,6 +12,7 @@ from pylabrobot.centrifuge.backend import CentrifugeBackend, LoaderBackend from pylabrobot.centrifuge.chatterbox import CentrifugeChatterboxBackend, LoaderChatterboxBackend from pylabrobot.resources import Coordinate, Cor_96_wellplate_360ul_Fb +from pylabrobot.testing.concurrency import AnyioTestBase class CentrifugeTests(unittest.IsolatedAsyncioTestCase): @@ -28,8 +29,8 @@ def test_serialization(self): self.assertEqual(deserialized, centrifuge) -class CentrifugeLoaderResourceModelTests(unittest.IsolatedAsyncioTestCase): - async def asyncSetUp(self): +class CentrifugeLoaderResourceModelTests(AnyioTestBase): + async def _enter_lifespan(self, stack): self.mock_centrifuge_backend = unittest.mock.MagicMock(spec=CentrifugeBackend) self.mock_loader_backend = unittest.mock.MagicMock(spec=LoaderBackend) self.centrifuge = Centrifuge( @@ -45,7 +46,6 @@ async def asyncSetUp(self): child_location=Coordinate.zero(), ) self.plate = Cor_96_wellplate_360ul_Fb(name="plate") - return await super().asyncSetUp() async def test_go_to_bucket(self): self.assertIsNone(self.centrifuge.at_bucket) diff --git a/pylabrobot/centrifuge/chatterbox.py b/pylabrobot/centrifuge/chatterbox.py index 4f32d678473..6aa42db0441 100644 --- a/pylabrobot/centrifuge/chatterbox.py +++ b/pylabrobot/centrifuge/chatterbox.py @@ -1,12 +1,12 @@ from pylabrobot.centrifuge.backend import CentrifugeBackend, LoaderBackend +from pylabrobot.concurrency import 
AsyncExitStackWithShielding class CentrifugeChatterboxBackend(CentrifugeBackend): - async def setup(self): + async def _enter_lifespan(self, stack: AsyncExitStackWithShielding): + await super()._enter_lifespan(stack) print("Setting up") - - async def stop(self): - print("Stopping") + stack.callback(lambda: print("Stopping")) async def open_door(self): print("Opening door") @@ -40,11 +40,10 @@ async def spin(self, g: float, duration: float, acceleration: float): class LoaderChatterboxBackend(LoaderBackend): - async def setup(self): + async def _enter_lifespan(self, stack: AsyncExitStackWithShielding): + await super()._enter_lifespan(stack) print("Setting up") - - async def stop(self): - print("Stopping") + stack.callback(lambda: print("Stopping")) async def load(self): print("Loading") diff --git a/pylabrobot/centrifuge/vspin_backend.py b/pylabrobot/centrifuge/vspin_backend.py index 021a550f359..89d499687bc 100644 --- a/pylabrobot/centrifuge/vspin_backend.py +++ b/pylabrobot/centrifuge/vspin_backend.py @@ -1,13 +1,14 @@ -import asyncio import ctypes import json import logging import math import os -import time import warnings from typing import Optional +import anyio + +from pylabrobot.concurrency import AsyncExitStackWithShielding from pylabrobot.io.ftdi import FTDI from .backend import CentrifugeBackend, LoaderBackend @@ -33,14 +34,16 @@ def __init__( async def _read(self) -> bytes: x = b"" r = None - start = time.time() - while r != b"" or x == b"": - r = await self.io.read(1) - x += r - if r == b"": - await asyncio.sleep(0.1) - if x == b"" and (time.time() - start) > self.timeout: - raise TimeoutError("No data received within the specified timeout period") + with anyio.move_on_after(self.timeout) as scope: + while r != b"" or x == b"": + r = await self.io.read(1) + x += r + if r == b"": + await anyio.sleep(0.1) + if x != b"": + scope.deadline = float("inf") + if x == b"" and scope.cancel_called: + raise TimeoutError("No data received within the specified timeout 
period") return x async def send_command(self, command: bytes) -> bytes: @@ -48,10 +51,11 @@ async def send_command(self, command: bytes) -> bytes: await self.io.write(command) return await self._read() - async def setup(self): + async def _enter_lifespan(self, stack: AsyncExitStackWithShielding): + await super()._enter_lifespan(stack) logger.debug("[loader] setup") - await self.io.setup() + await stack.enter_async_context(self.io) await self.io.set_baudrate(115384) status = await self.get_status() @@ -71,10 +75,6 @@ async def setup(self): await self.send_command(bytes.fromhex("1105000e00440b00000000000000007041020203c7")) # await self.send_command(bytes.fromhex("11050003002000006bd4")) - async def stop(self): - logger.debug("[loader] stop") - await self.io.stop() - def serialize(self): return {"io": self.io.serialize(), "timeout": self.timeout} @@ -187,8 +187,15 @@ def __init__(self, device_id: Optional[str] = None): if device_id is not None: self._bucket_1_remainder = _load_vspin_calibrations(device_id) - async def setup(self): - await self.io.setup() + async def _enter_lifespan(self, stack: AsyncExitStackWithShielding): + await super()._enter_lifespan(stack) + await stack.enter_async_context(self.io) + + async def _cleanup(): + await self.configure_and_initialize() + + stack.push_shielded_async_callback(_cleanup) + # TODO: add functionality where if robot has been initialized before nothing needs to happen for _ in range(3): await self.configure_and_initialize() @@ -298,10 +305,6 @@ async def get_bucket_1_position(self) -> int: ) return bucket_1_position - async def stop(self): - await self.configure_and_initialize() - await self.io.stop() - class _StatusPositionTachometer(ctypes.LittleEndianStructure): _pack_ = 1 _fields_ = [ @@ -385,19 +388,22 @@ async def _read_resp(self, timeout: float = 20) -> bytes: been read so far.""" data = b"" end_byte_found = False - start_time = time.time() - - while True: - chunk = await self.io.read(25) - if chunk: - data += chunk 
- end_byte_found = data[-1] == 0x0D - if len(chunk) < 25 and end_byte_found: - break - else: - if end_byte_found or time.time() - start_time > timeout: - break - await asyncio.sleep(0.0001) + + with anyio.move_on_after(timeout) as scope: + while True: + chunk = await self.io.read(25) + if chunk: + data += chunk + end_byte_found = data[-1] == 0x0D + if len(chunk) < 25 and end_byte_found: + break + else: + if end_byte_found: + break + await anyio.sleep(0.0001) + + if scope.cancel_called: + logger.warning("timed out reading response") logger.debug("Read %s", data.hex()) return data @@ -436,7 +442,7 @@ async def open_door(self): await self._send_command(bytes.fromhex("aa022600062e")) # same as unlock door # we can't tell when the door is fully open, so we just wait a bit - await asyncio.sleep(4) + await anyio.sleep(4) async def close_door(self): if not (await self.get_door_open()): @@ -444,7 +450,7 @@ async def close_door(self): # used to be: aa022600052d await self._send_command(bytes.fromhex("aa022600042c")) # same as unlock door # we can't tell when the door is fully closed, so we just wait a bit - await asyncio.sleep(2) + await anyio.sleep(2) async def lock_door(self): if await self.get_door_open(): @@ -497,7 +503,7 @@ async def go_to_position(self, position: int): while ( abs(await self.get_position() - position) > 10 ): # 10 tacks tolerance (10/8000 * 360 = 0.45 degrees) - await asyncio.sleep(0.1) + await anyio.sleep(0.1) await self.open_door() @staticmethod @@ -586,7 +592,7 @@ async def spin( # 3 - wait for acceleration to the set rpm # we also check the position to avoid waiting forever if the speed is not reached (e.g. short spin...) 
while await self.get_tachometer() < rpm * 0.95 and await self.get_position() < final_position: - await asyncio.sleep(0.1) + await anyio.sleep(0.1) # 4 - once the speed is reached, compute the position at which to start deceleration # this is different than computed above, because above we assumed constant acceleration from 0 to rpm. @@ -598,7 +604,7 @@ async def spin( # then wait until we reach that position while await self.get_position() < decel_start_position: - await asyncio.sleep(0.1) + await anyio.sleep(0.1) # 5 - send deceleration command await self._send_command(bytes.fromhex("aa01e60500640000000000fd00803e01000c")) @@ -610,7 +616,7 @@ async def spin( decel_command += ((sum(decel_command) - 0xAA) & 0xFF).to_bytes(1, byteorder="little") await self._send_command(decel_command) - await asyncio.sleep(2) + await anyio.sleep(2) # 6 - reset position back to 0ish # this part is aneeded because otherwise calling go_to_position will not work after @@ -632,7 +638,7 @@ async def _reset_to_zero(): start = await self.get_home_position() num_tries = 0 while await self.get_home_position() == start: - await asyncio.sleep(0.1) + await anyio.sleep(0.1) num_tries += 1 if num_tries % 25 == 0: await _reset_to_zero() diff --git a/pylabrobot/concurrency.py b/pylabrobot/concurrency.py new file mode 100644 index 00000000000..8bc51141d0b --- /dev/null +++ b/pylabrobot/concurrency.py @@ -0,0 +1,278 @@ +import abc +import asyncio +import contextlib +import dataclasses +import functools +import sys +import typing +import warnings + +if sys.version_info >= (3, 10): + from typing import Any, Optional, TypeAlias +else: + from typing_extensions import Any, Optional, TypeAlias + +import anyio +import sniffio + + +class MachineConnectionClosedError(Exception): + """Raised when a machine task is being aborted because the connection is, or has been closed.""" + + +class AsyncExitStackWithShielding(contextlib.AsyncExitStack): + def push_shielded_async_callback(self, callback: typing.Callable, 
*args): + @functools.wraps(callback) + async def shielded_callback(*args): + with anyio.CancelScope(shield=True): + await callback(*args) + + self.push_async_callback(shielded_callback, *args) + + +@dataclasses.dataclass(frozen=True) +class _LifespanLifecycleTag: + """Tags used to represent the lifecycle of a lifespan, + for accurate double-entry checking.""" + + name: str + + +LifespanEntering = _LifespanLifecycleTag("entering") +LifespanExiting = _LifespanLifecycleTag("exiting") +AnonymousLifespan = _LifespanLifecycleTag("anonymous") + + +class _AsyncResourceBase: + """Implementation of `AsyncResource`, but without any `__new__` to implement ABC checking.""" + + async def _enter_lifespan(self, stack: AsyncExitStackWithShielding): + """Helper for the _lifespan implementation; override this instead of _lifespan. + + Note, child classes may add keyword-only arguments to the signature, as _lifespan + forwards those. + """ + raise NotImplementedError("Subclasses must override _enter_lifespan or _lifespan.") + + @contextlib.asynccontextmanager + async def _lifespan(self, **kwargs): + """The resource's lifespan. + + Subclasses should override this method to provide their own lifespan. + Alternatively, they can provide `_enter_lifespan(stack)` which gets called with an `AsyncExitStack`. + """ + # double-entry checking, using _active_lifespan as signalling mechanism. + # this double-entry checking here isn't strictly necessary, since usually, + # we always enter through __aenter__. + active_lifespan = getattr(self, "_active_lifespan", None) + if active_lifespan is None: + # This is a direct call to _lifespan, not going through __aenter__. + # we don't have access to the context manager, so we just store a tag. 
+ self._active_lifespan = AnonymousLifespan + elif active_lifespan is not LifespanEntering: + raise RuntimeError(f"lifespan of {type(self).__name__} is already entered") + + # main implementation + try: + async with AsyncExitStackWithShielding() as stack: + await self._enter_lifespan(stack, **kwargs) + yield self + # there shouldn't be anything here; explicit cleanup is difficult to get right + # in face of exceptions and cancellation; register your cleanup when you enter. + finally: + if self._active_lifespan is AnonymousLifespan: + self._active_lifespan = None # type: ignore[assignment] + + async def __aenter__(self): + """Enter the resource's lifespan. + This method should not be overridden by subclasses; + separate `__aenter__` and `__aexit__` calls are difficult to implement correctly, + implement `_lifespan` or `_enter_lifespan` instead. + """ + if getattr(self, "_active_lifespan", None) is not None: + raise RuntimeError(f"lifespan of {type(self).__name__} is already entered") + + try: + self._active_lifespan = LifespanEntering + active_lifespan = self._lifespan() + await active_lifespan.__aenter__() + self._active_lifespan = active_lifespan # type: ignore[assignment] + except: + self._active_lifespan = None # type: ignore[assignment] + raise + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Exit the resource's context. + This method should never be overridden. 
+ """ + try: + active_lifespan = self._active_lifespan + self._active_lifespan = LifespanExiting + ret = await active_lifespan.__aexit__(exc_type, exc_val, exc_tb) # type: ignore[attr-defined] + finally: + self._active_lifespan = None # type: ignore[assignment] + return ret + + +class AsyncResource(_AsyncResourceBase, abc.ABC): + """An abstract base class for all resources.""" + + def __new__(cls, *args, **kwargs): + # Check if both methods are still the base implementations + if ( + cls._enter_lifespan is AsyncResource._enter_lifespan + and cls._lifespan is _AsyncResourceBase._lifespan + ): + raise TypeError( + f"Can't instantiate abstract class {cls.__name__} " + "without an implementation for either '_enter_lifespan' or '_lifespan'" + ) + + return super().__new__(cls) + + async def _enter_lifespan(self, stack: AsyncExitStackWithShielding): + # Non-throwing base class implementation, so that derived classes can + # call super()._enter_lifespan() without knowing how many classes are in the chain. 
+ pass + + +MachineID: TypeAlias = Any + + +class GlobalManager: + """A global task manager to enable interactive (notebook) usage of async context managers.""" + + def __init__(self): + self._tg: Optional[anyio.abc.TaskGroup] = None + self._running_task: Optional[asyncio.Task] = None + self._started: Optional[anyio.Event] = None + self._stop: Optional[anyio.Event] = None + self._pending: set[MachineID] = set() + self._stop_events: dict[MachineID, anyio.Event] = {} + self._exit_events: dict[MachineID, anyio.Event] = {} + self._errors: dict[MachineID, Exception] = {} + + async def _run_global_task_group(self): + async with anyio.create_task_group() as tg: + assert self._tg is None + self._tg = tg + self._stop = anyio.Event() + assert self._started is not None + self._started.set() + await self._stop.wait() + + @contextlib.asynccontextmanager + async def _reserve_runner_for(self, obj): + try: + backend = sniffio.current_async_library() + except sniffio.AsyncLibraryNotFoundError: + backend = "asyncio" + + if backend != "asyncio": + raise RuntimeError( + f"The global manager for interactive setup/stop is currently only supported " + f"on asyncio (Jupyter). Caught: {backend}. Please use `async with machine:` directly." 
+ ) + + loop = asyncio.get_running_loop() + + try: + self._pending.add(obj) + if self._tg is None: + try: + self._started = anyio.Event() + self._running_task = loop.create_task(self._run_global_task_group()) + await self._started.wait() + finally: + self._started = None + yield self._tg + finally: + self._pending.discard(obj) + + async def manage_context(self, obj: Any): + """Schedules an object's async context manager into the global task group.""" + + stop_event = self._stop_events.get(obj) + if stop_event is not None: + warnings.warn(f"Object {obj} is already managed by the global task group.") + return + warnings.warn( + "Prefer using structured concurrency (`async with resource:`) over `.setup` calls.", + DeprecationWarning, + ) + + async def wrapper(*, task_status=anyio.TASK_STATUS_IGNORED): + try: + print("entering obj context manager") + async with obj: + print("entered obj context manager") + task_status.started() + assert stop_event is not None + await stop_event.wait() + except Exception as e: + self._errors[obj] = e + finally: + self._stop_events.pop(obj, None) + exit_event = self._exit_events.get(obj) + if exit_event is not None: + exit_event.set() + if not self._pending and not self._stop_events: + assert self._stop is not None + self._stop.set() + self._stop = None + self._tg = None + self._running_task = None + + try: + async with self._reserve_runner_for(obj): + self._stop_events[obj] = stop_event = anyio.Event() + assert self._tg is not None + await self._tg.start(wrapper) + except Exception: + self._stop_events.pop(obj, None) + raise + finally: + e = self._errors.pop(obj, None) + if e is not None: + raise e + + async def release_context(self, obj: Any): + """Signals the given object's context manager to gracefully exit.""" + errors = self._errors.pop(obj, None) + if errors is not None: + raise errors + + stop_event = self._stop_events.pop(obj, None) + + if stop_event is None: + warnings.warn(f"Object {obj} is not managed by the global task 
group. ") + return + + try: + self._exit_events[obj] = exit_event = anyio.Event() + stop_event.set() + await exit_event.wait() + finally: + self._exit_events.pop(obj, None) + + async def stop_all(self): + """Forcefully stops all managed objects and terminates the global TaskGroup.""" + + async def do_release(obj, go, *, task_status=anyio.TASK_STATUS_IGNORED): + with anyio.CancelScope(shield=True): + task_status.started() + await go.wait() + await self.release_context(obj) + + # Release all managed objects simultaneously in a task group. + # Each release is shielded from cancellation; this guarantees that + # all objects attempt to exit, and we get all errors in one ExceptionGroup. + async with anyio.create_task_group() as tg: + go = anyio.Event() + for obj in list(self._stop_events.keys()): + await tg.start(do_release, obj, go) + go.set() + + +global_manager = GlobalManager() diff --git a/pylabrobot/heating_shaking/bioshake_backend.py b/pylabrobot/heating_shaking/bioshake_backend.py index 2816363b5a2..c94a78b5333 100644 --- a/pylabrobot/heating_shaking/bioshake_backend.py +++ b/pylabrobot/heating_shaking/bioshake_backend.py @@ -1,9 +1,10 @@ -import asyncio import warnings +import anyio + +from pylabrobot.concurrency import AsyncExitStackWithShielding from pylabrobot.heating_shaking.backend import HeaterShakerBackend from pylabrobot.io.serial import Serial -from pylabrobot.machines.backend import MachineBackend try: import serial @@ -43,13 +44,14 @@ async def _send_command(self, cmd: str, delay: float = 0.5, timeout: float = 2): # Send the command await self.io.write((cmd + "\r").encode("ascii")) - await asyncio.sleep(delay) + await anyio.sleep(delay) # Read and decode the response with a timeout try: - response = await asyncio.wait_for(self.io.readline(), timeout=timeout) + with anyio.fail_after(timeout): + response = await self.io.readline() - except asyncio.TimeoutError: + except TimeoutError: raise RuntimeError(f"Timed out waiting for response to '{cmd}'") 
     decoded = response.decode("ascii", errors="ignore").strip()
@@ -77,21 +79,17 @@ async def _send_command(self, cmd: str, delay: float = 0.5, timeout: float = 2):
     except Exception as e:
       raise RuntimeError(f"Unexpected error while sending '{cmd}': {type(e).__name__}: {e}") from e

-  async def setup(self, skip_home: bool = False):
-    await MachineBackend.setup(self)
-    await self.io.setup()
+  async def _enter_lifespan(self, stack: AsyncExitStackWithShielding, *, skip_home: bool = False):
+    await super()._enter_lifespan(stack)
+    await stack.enter_async_context(self.io)
     if not skip_home:
       # Reset first before homing it to ensure the device is ready for run
       await self.reset()
       # Additional seconds until next command can be sent after reset
-      await asyncio.sleep(4)
+      await anyio.sleep(4)
       # Now home the device
       await self.home()

-  async def stop(self):
-    await MachineBackend.stop(self)
-    await self.io.stop()
-
   async def reset(self):
     # Reset the BioShake if stuck in "e" state
     # Flush serial buffers for a clean start
@@ -101,28 +99,18 @@ async def reset(self):
     # Send the command
     await self.io.write(("resetDevice\r").encode("ascii"))

-    start = asyncio.get_event_loop().time()
-    max_seconds = 30  # How long a reset typically last
-
-    while True:
-      # Break the loop if process takes longer than 30 seconds
-      if asyncio.get_event_loop().time() - start > max_seconds:
-        raise TimeoutError("Reset did not complete in time")
-
-      try:
-        # Wait for each line with a timeout
-        response = await asyncio.wait_for(self.io.readline(), timeout=2)
-        decoded = response.decode("ascii", errors="ignore").strip()
-        await asyncio.sleep(0.1)
-
-        if len(decoded) > 0:
-          # Stop when the final message arrives
-          if "Initialization complete" in decoded:
-            break
-
-      except asyncio.TimeoutError:
-        # Keep polling if nothing arrives within timeout
-        continue
+    try:
+      with anyio.fail_after(30):
+        while True:
+          response = await self.io.readline()
+          decoded = response.decode("ascii", errors="ignore").strip()
+          await anyio.sleep(0.1)
+          if len(decoded) > 0:
+            # Stop when the final message arrives
+            if "Initialization complete" in decoded:
+              break
+    except TimeoutError:
+      raise TimeoutError("Reset did not complete in time") from None

   async def home(self):
     # Initialize the BioShake into home position
@@ -216,7 +204,7 @@ async def stop_shaking(self, deceleration: int = 0):
     # before the edge-locking mechanism (ELM) can operate. Without this
     # delay, subsequent setElmUnlockPos commands return 'e' (error).
     sleep_time_after_stop = 3
-    await asyncio.sleep(sleep_time_after_stop)
+    await anyio.sleep(sleep_time_after_stop)

   @property
   def supports_locking(self) -> bool:
diff --git a/pylabrobot/heating_shaking/hamilton_backend.py b/pylabrobot/heating_shaking/hamilton_backend.py
index f302502115a..0e210a65e86 100644
--- a/pylabrobot/heating_shaking/hamilton_backend.py
+++ b/pylabrobot/heating_shaking/hamilton_backend.py
@@ -1,9 +1,12 @@
 import abc
-import time
+import contextlib
 import warnings
 from enum import Enum
 from typing import Dict, Literal, Optional

+import anyio
+
+from pylabrobot.concurrency import AsyncExitStackWithShielding
 from pylabrobot.heating_shaking.backend import HeaterShakerBackend
 from pylabrobot.io.usb import USB

@@ -29,6 +32,9 @@ def __init__(
     device_address: Optional[int] = None,
     serial_number: Optional[str] = None,
   ):
+    """
+    If io fails to connect, ensure that libusb drivers were installed for the HHS as per docs.
+    """
     self.io = USB(
       human_readable_device_name="Hamilton Heater Shaker Box",
       id_vendor=id_vendor,
@@ -43,15 +49,6 @@ def _generate_id(self) -> int:
     self._id += 1
     return self._id % 10000

-  async def setup(self):
-    """
-    If io.setup() fails, ensure that libusb drivers were installed for the HHS as per docs.
- """ - await self.io.setup() - - async def stop(self): - await self.io.stop() - async def send_hhs_command(self, index: int, command: str, **kwargs) -> str: args = "".join([f"{key}{value}" for key, value in kwargs.items()]) id_ = str(self._generate_id()).zfill(4) @@ -77,16 +74,11 @@ def __init__(self, index: int, interface: HamiltonHeaterShakerInterface) -> None super().__init__() self.interface = interface - async def setup(self): - """ - If io.setup() fails, ensure that libusb drivers were installed for the HHS as per docs. - """ + async def _enter_lifespan(self, stack: AsyncExitStackWithShielding): + await super()._enter_lifespan(stack) await self._initialize_lock() await self._initialize_shaker_drive() - async def stop(self): - pass - def serialize(self) -> dict: warnings.warn("The interface is not serialized.") @@ -119,13 +111,14 @@ async def start_shaking( assert direction in [0, 1], "Direction must be 0 or 1" assert 500 <= acceleration <= 10_000, "Acceleration must be between 500 and 10_000" - now = time.time() - while True: - await self._start_shaking(direction=direction, speed=int_speed, acceleration=acceleration) - if await self.get_is_shaking(): - break - if timeout is not None and time.time() - now > timeout: - raise TimeoutError("Failed to start shaking within timeout") + async with contextlib.AsyncExitStack() as stack: + if timeout is not None: + stack.enter_context(anyio.fail_after(timeout)) + while True: + await self._start_shaking(direction=direction, speed=int_speed, acceleration=acceleration) + if await self.get_is_shaking(): + break + await anyio.sleep(0.1) async def shake( self, diff --git a/pylabrobot/heating_shaking/heater_shaker.py b/pylabrobot/heating_shaking/heater_shaker.py index 39437de1ee7..4b798a5417d 100644 --- a/pylabrobot/heating_shaking/heater_shaker.py +++ b/pylabrobot/heating_shaking/heater_shaker.py @@ -1,6 +1,6 @@ from typing import Optional -from pylabrobot.machines.machine import Machine +from pylabrobot.concurrency import 
AsyncExitStackWithShielding from pylabrobot.resources.coordinate import Coordinate from pylabrobot.shaking import Shaker from pylabrobot.temperature_controlling import TemperatureController @@ -34,7 +34,11 @@ def __init__( ) self.backend: HeaterShakerBackend = backend # fix type - async def stop(self): - await self.deactivate() - await self.stop_shaking() - await Machine.stop(self) + async def _enter_lifespan(self, stack: AsyncExitStackWithShielding) -> None: + await super()._enter_lifespan(stack) + + async def cleanup(): + await self.deactivate() + await self.stop_shaking() + + stack.push_shielded_async_callback(cleanup) diff --git a/pylabrobot/heating_shaking/inheco/thermoshake_backend.py b/pylabrobot/heating_shaking/inheco/thermoshake_backend.py index 76efc556b36..a696373b9f8 100644 --- a/pylabrobot/heating_shaking/inheco/thermoshake_backend.py +++ b/pylabrobot/heating_shaking/inheco/thermoshake_backend.py @@ -1,5 +1,6 @@ import warnings +from pylabrobot.concurrency import AsyncExitStackWithShielding from pylabrobot.heating_shaking.backend import HeaterShakerBackend from pylabrobot.temperature_controlling.inheco.temperature_controller import ( InhecoTemperatureControllerBackend, @@ -12,9 +13,9 @@ class InhecoThermoshakeBackend(InhecoTemperatureControllerBackend, HeaterShakerB https://www.inheco.com/thermoshake-ac.html """ - async def stop(self): - await self.stop_shaking() - await super().stop() + async def _enter_lifespan(self, stack: AsyncExitStackWithShielding): + await super()._enter_lifespan(stack) + stack.push_shielded_async_callback(self.stop_shaking) async def _start_shaking_command(self): """Send the device command that starts shaking with the configured settings.""" diff --git a/pylabrobot/io/ftdi.py b/pylabrobot/io/ftdi.py index e33b48e1b54..75f7ca57df7 100644 --- a/pylabrobot/io/ftdi.py +++ b/pylabrobot/io/ftdi.py @@ -1,10 +1,11 @@ -import asyncio import ctypes import logging -from concurrent.futures import ThreadPoolExecutor -from io import IOBase 
 from typing import Optional, cast

+import anyio
+
+from pylabrobot.concurrency import AsyncExitStackWithShielding
+
 try:
   import pylibftdi.driver
   from pylibftdi import Device, FtdiError
@@ -25,6 +26,7 @@
 from pylabrobot.io.capture import CaptureReader, Command, capturer, get_capture_or_validation_active
 from pylabrobot.io.errors import ValidationError
+from pylabrobot.io.io import IOBase
 from pylabrobot.io.validation_utils import LOG_LEVEL_IO, align_sequences

 logger = logging.getLogger(__name__)
@@ -82,7 +84,7 @@ def __init__(

     # Will be resolved in setup()
     self._dev: Optional[Device] = None
-    self._executor: Optional[ThreadPoolExecutor] = None
+    self._lock = anyio.Lock()

     if get_capture_or_validation_active():
       raise RuntimeError(
@@ -178,13 +180,13 @@ def _resolve_device_serial(self) -> str:
     device_serial_number = cast(str, usb.util.get_string(device, device.iSerialNumber))
     return device_serial_number

-  async def setup(self):
+  async def _enter_lifespan(self, stack: AsyncExitStackWithShielding):
     """Initialize the FTDI device connection with device resolution."""
     if self._dev is not None and not self._dev.closed:
-      self._dev.close()
+      await anyio.to_thread.run_sync(self._dev.close)
     try:
       # Resolve which device to connect to
-      self._device_id = self._resolve_device_serial()
+      self._device_id = await anyio.to_thread.run_sync(self._resolve_device_serial)

       # Create and open device
       self._dev = Device(
@@ -194,7 +196,7 @@ async def setup(self):
         vid=self._vid,
         interface_select=self._interface_select,
       )
-      self._dev.open()
+      await anyio.to_thread.run_sync(self._dev.open)
       logger.info(f"Successfully opened FTDI device: {self.device_id}")
     except FtdiError as e:
       raise RuntimeError(
@@ -203,7 +205,12 @@ async def setup(self):
         "Try restarting the kernel."
       ) from e

-    self._executor = ThreadPoolExecutor(max_workers=1)
+    async def _cleanup():
+      if self._dev is not None:
+        await anyio.to_thread.run_sync(self._dev.close)
+        self._dev = None
+
+    stack.push_shielded_async_callback(_cleanup)

   @property
   def device_id(self) -> str:
@@ -211,47 +218,41 @@ def device_id(self) -> str:
       raise RuntimeError("Device not initialized. Call setup() first.")
     return self._device_id

+  async def _dev_call(self, func, *args):
+    async with self._lock:
+      return await anyio.to_thread.run_sync(func, *args)
+
   async def set_baudrate(self, baudrate: int):
-    loop = asyncio.get_running_loop()
-    await loop.run_in_executor(self._executor, lambda: setattr(self.dev, "baudrate", baudrate))
+    await self._dev_call(setattr, self.dev, "baudrate", baudrate)
     logger.log(LOG_LEVEL_IO, "[%s] set_baudrate %s", self._device_id, baudrate)
     capturer.record(
       FTDICommand(device_id=self.device_id, action="set_baudrate", data=str(baudrate))
     )

   async def set_rts(self, level: bool):
-    loop = asyncio.get_running_loop()
-    await loop.run_in_executor(self._executor, lambda: self.dev.ftdi_fn.ftdi_setrts(level))
+    await self._dev_call(self.dev.ftdi_fn.ftdi_setrts, level)
     logger.log(LOG_LEVEL_IO, "[%s] set_rts %s", self._device_id, level)
     capturer.record(FTDICommand(device_id=self.device_id, action="set_rts", data=str(level)))

   async def set_dtr(self, level: bool):
-    loop = asyncio.get_running_loop()
-    await loop.run_in_executor(self._executor, lambda: self.dev.ftdi_fn.ftdi_setdtr(level))
+    await self._dev_call(self.dev.ftdi_fn.ftdi_setdtr, level)
     logger.log(LOG_LEVEL_IO, "[%s] set_dtr %s", self._device_id, level)
     capturer.record(FTDICommand(device_id=self.device_id, action="set_dtr", data=str(level)))

   async def usb_reset(self):
-    loop = asyncio.get_running_loop()
-    await loop.run_in_executor(self._executor, lambda: self.dev.ftdi_fn.ftdi_usb_reset())
+    await self._dev_call(self.dev.ftdi_fn.ftdi_usb_reset)
     logger.log(LOG_LEVEL_IO, "[%s] usb_reset", self._device_id)
     capturer.record(FTDICommand(device_id=self.device_id, action="usb_reset", data=""))

   async def set_latency_timer(self, latency: int):
-    loop = asyncio.get_running_loop()
-    await loop.run_in_executor(
-      self._executor, lambda: self.dev.ftdi_fn.ftdi_set_latency_timer(latency)
-    )
+    await self._dev_call(self.dev.ftdi_fn.ftdi_set_latency_timer, latency)
     logger.log(LOG_LEVEL_IO, "[%s] set_latency_timer %s", self._device_id, latency)
     capturer.record(
       FTDICommand(device_id=self.device_id, action="set_latency_timer", data=str(latency))
     )

   async def set_line_property(self, bits: int, stopbits: int, parity: int):
-    loop = asyncio.get_running_loop()
-    await loop.run_in_executor(
-      self._executor, lambda: self.dev.ftdi_fn.ftdi_set_line_property(bits, stopbits, parity)
-    )
+    await self._dev_call(self.dev.ftdi_fn.ftdi_set_line_property, bits, stopbits, parity)
     logger.log(
       LOG_LEVEL_IO, "[%s] set_line_property %s,%s,%s", self._device_id, bits, stopbits, parity
     )
@@ -262,31 +263,25 @@ async def set_line_property(self, bits: int, stopbits: int, parity: int):
     )

   async def set_flowctrl(self, flowctrl: int):
-    loop = asyncio.get_running_loop()
-    await loop.run_in_executor(self._executor, lambda: self.dev.ftdi_fn.ftdi_setflowctrl(flowctrl))
+    await self._dev_call(self.dev.ftdi_fn.ftdi_setflowctrl, flowctrl)
     logger.log(LOG_LEVEL_IO, "[%s] set_flowctrl %s", self._device_id, flowctrl)
     capturer.record(
       FTDICommand(device_id=self.device_id, action="set_flowctrl", data=str(flowctrl))
     )

   async def usb_purge_rx_buffer(self):
-    loop = asyncio.get_running_loop()
-    await loop.run_in_executor(self._executor, lambda: self.dev.ftdi_fn.ftdi_usb_purge_rx_buffer())
+    await self._dev_call(self.dev.ftdi_fn.ftdi_usb_purge_rx_buffer)
     logger.log(LOG_LEVEL_IO, "[%s] usb_purge_rx_buffer", self._device_id)
     capturer.record(FTDICommand(device_id=self.device_id, action="usb_purge_rx_buffer", data=""))

   async def usb_purge_tx_buffer(self):
-    loop = asyncio.get_running_loop()
-    await loop.run_in_executor(self._executor, lambda: self.dev.ftdi_fn.ftdi_usb_purge_tx_buffer())
+    await self._dev_call(self.dev.ftdi_fn.ftdi_usb_purge_tx_buffer)
     logger.log(LOG_LEVEL_IO, "[%s] usb_purge_tx_buffer", self._device_id)
     capturer.record(FTDICommand(device_id=self.device_id, action="usb_purge_tx_buffer", data=""))

   async def poll_modem_status(self) -> int:
-    loop = asyncio.get_running_loop()
     stat = ctypes.c_ushort(0)
-    await loop.run_in_executor(
-      self._executor, lambda: self.dev.ftdi_fn.ftdi_poll_modem_status(ctypes.byref(stat))
-    )
+    await self._dev_call(self.dev.ftdi_fn.ftdi_poll_modem_status, ctypes.byref(stat))
     logger.log(LOG_LEVEL_IO, "[%s] poll_modem_status %s", self._device_id, stat.value)
     capturer.record(
       FTDICommand(device_id=self.device_id, action="poll_modem_status", data=str(stat.value))
@@ -296,21 +291,14 @@ async def poll_modem_status(self) -> int:
   async def get_serial(self) -> str:
     return self.device_id

-  async def stop(self):
-    if self._dev is not None:
-      self.dev.close()
-    if self._executor is not None:
-      self._executor.shutdown(wait=True)
-      self._executor = None
-
   async def write(self, data: bytes) -> int:
     """Write data to the device. Returns the number of bytes written."""
     logger.log(LOG_LEVEL_IO, "[%s] write %s", self._device_id, data)
     capturer.record(FTDICommand(device_id=self.device_id, action="write", data=data.hex()))
-    return cast(int, self.dev.write(data))
+    return cast(int, await self._dev_call(self.dev.write, data))

   async def read(self, num_bytes: int = 1) -> bytes:
-    data = self.dev.read(num_bytes)
+    data = await self._dev_call(self.dev.read, num_bytes)
     logger.log(LOG_LEVEL_IO, "[%s] read %s", self._device_id, data)
     capturer.record(
       FTDICommand(
@@ -322,7 +310,7 @@ async def read(self, num_bytes: int = 1) -> bytes:
     return cast(bytes, data)

   async def readline(self) -> bytes:  # type: ignore  # very dumb it's reading from pyserial
-    data = self.dev.readline()
+    data = await self._dev_call(self.dev.readline)
     logger.log(LOG_LEVEL_IO, "[%s] readline %s", self._device_id, data)
     capturer.record(FTDICommand(device_id=self.device_id, action="readline", data=data.hex()))
     return cast(bytes, data)
diff --git a/pylabrobot/io/hid.py b/pylabrobot/io/hid.py
index f6b24ff3a51..037baa231f1 100644
--- a/pylabrobot/io/hid.py
+++ b/pylabrobot/io/hid.py
@@ -1,8 +1,10 @@
-import asyncio
+import contextlib
 import logging
-from concurrent.futures import ThreadPoolExecutor
 from typing import Optional, cast

+import anyio
+
+from pylabrobot.concurrency import AsyncExitStackWithShielding
 from pylabrobot.io.capture import CaptureReader, Command, capturer, get_capture_or_validation_active
 from pylabrobot.io.errors import ValidationError
 from pylabrobot.io.io import IOBase
@@ -38,12 +40,12 @@ def __init__(
     self.serial_number = serial_number
     self.device: Optional[hid.Device] = None
     self._unique_id = f"{vid}:{pid}:{serial_number}"
-    self._executor: Optional[ThreadPoolExecutor] = None
+    self._lock = anyio.Lock()

     if get_capture_or_validation_active():
       raise RuntimeError("Cannot create a new HID object while capture or validation is active")

-  async def setup(self):
+  async def _enter_lifespan(self, stack: AsyncExitStackWithShielding):
     """
     Sets up the HID device by enumerating connected devices, matching the specified
     VID, PID, and optional serial number, and opening a connection to the device.
@@ -55,7 +57,7 @@ async def setup(self):
     )

     # --- 1. Enumerate all HID devices ---
-    all_devices = hid.enumerate()
+    all_devices = await anyio.to_thread.run_sync(hid.enumerate)
     candidates = [
       d
       for d in all_devices
@@ -99,22 +101,22 @@ async def setup(self):
       chosen = candidates[0]

     # --- 5. Open the device ---
-    self.device = hid.Device(
-      path=chosen["path"]  # safer than vid/pid/serial triple
-    )
-    self._executor = ThreadPoolExecutor(max_workers=1)
+    self.device = await anyio.to_thread.run_sync(lambda: hid.Device(path=chosen["path"]))
+
+    async def _cleanup():
+      if self.device is not None:
+        await anyio.to_thread.run_sync(self.device.close)
+        logger.log(LOG_LEVEL_IO, "Closing HID device %s", self._unique_id)
+        capturer.record(HIDCommand(device_id=self._unique_id, action="close", data=""))
+
+    stack.push_shielded_async_callback(_cleanup)

     logger.log(LOG_LEVEL_IO, "Opened HID device %s", self._unique_id)
     capturer.record(HIDCommand(device_id=self._unique_id, action="open", data=""))

-  async def stop(self):
-    if self.device is not None:
-      self.device.close()
-      logger.log(LOG_LEVEL_IO, "Closing HID device %s", self._unique_id)
-      capturer.record(HIDCommand(device_id=self._unique_id, action="close", data=""))
-    if self._executor is not None:
-      self._executor.shutdown(wait=True)
-      self._executor = None
+  async def _dev_call(self, func, *args):
+    async with self._lock:
+      return await anyio.to_thread.run_sync(func, *args)

   async def write(self, data: bytes, report_id: bytes = b"\x00"):
     r"""Writes data to the HID device.
@@ -139,7 +141,6 @@ async def write(self, data: bytes, report_id: bytes = b"\x00"):
       data: The data to write.
       report_id: The report ID to use for the write operation. Defaults to b'\x00'.
     """
-    loop = asyncio.get_running_loop()
     write_data = report_id + data

     def _write():
@@ -147,9 +148,7 @@ def _write():
         raise RuntimeError(f"Call setup() first for device '{self._human_readable_device_name}'.")
       return self.device.write(write_data)

-    if self._executor is None:
-      raise RuntimeError("Call setup() first.")
-    r = await loop.run_in_executor(self._executor, _write)
+    r = await self._dev_call(_write)
     logger.log(
       LOG_LEVEL_IO, "[%s] write %s (report_id: %s)", self._unique_id, data, report_id.hex()
     )
@@ -157,8 +156,6 @@ def _write():
     return r

   async def read(self, size: int, timeout: int) -> bytes:
-    loop = asyncio.get_running_loop()
-
     def _read():
       if self.device is None:
         raise RuntimeError(f"Call setup() first for device '{self._human_readable_device_name}'.")
@@ -169,9 +166,7 @@ def _read():
           return b""
         raise

-    if self._executor is None:
-      raise RuntimeError("Call setup() first.")
-    r = await loop.run_in_executor(self._executor, _read)
+    r = await self._dev_call(_read)
     if len(r) > 0:
       logger.log(LOG_LEVEL_IO, "[%s] read %s", self._unique_id, r)
       capturer.record(HIDCommand(device_id=self._unique_id, action="read", data=r.hex()))
@@ -203,7 +198,7 @@ def __init__(
     )
     self.cr = cr

-  async def setup(self):
+  async def _enter_lifespan(self, stack: contextlib.AsyncExitStack):
     next_command = HIDCommand(**self.cr.next_command())
     if (
       not next_command.module == "hid"
@@ -212,14 +207,16 @@ async def setup(self):
     ):
       raise ValidationError(f"Next line is {next_command}, expected HID open {self._unique_id}")

-  async def stop(self):
-    next_command = HIDCommand(**self.cr.next_command())
-    if (
-      not next_command.module == "hid"
-      and next_command.device_id == self._unique_id
-      and next_command.action == "close"
-    ):
-      raise ValidationError(f"Next line is {next_command}, expected HID close {self._unique_id}")
+    def _cleanup():
+      next_command = HIDCommand(**self.cr.next_command())
+      if (
+        not next_command.module == "hid"
+        and next_command.device_id == self._unique_id
+        and next_command.action == "close"
+      ):
+        raise ValidationError(f"Next line is {next_command}, expected HID close {self._unique_id}")
+
+    stack.callback(_cleanup)

   async def write(self, data: bytes, report_id: bytes = b"\x00"):
     next_command = HIDCommand(**self.cr.next_command())
diff --git a/pylabrobot/io/io.py b/pylabrobot/io/io.py
index 599399a251f..c2b72cc852b 100644
--- a/pylabrobot/io/io.py
+++ b/pylabrobot/io/io.py
@@ -1,9 +1,13 @@
-from abc import ABC, abstractmethod
+from abc import abstractmethod

+from pylabrobot.concurrency import AsyncExitStackWithShielding, AsyncResource
 from pylabrobot.serializer import SerializableMixin


-class IOBase(SerializableMixin, ABC):
+class IOBase(SerializableMixin, AsyncResource):
+  async def _enter_lifespan(self, stack: AsyncExitStackWithShielding):
+    pass
+
   @abstractmethod
   async def write(self, data: bytes, *args, **kwargs):
     pass
diff --git a/pylabrobot/io/serial.py b/pylabrobot/io/serial.py
index 914d4567d19..1d0bad12e15 100644
--- a/pylabrobot/io/serial.py
+++ b/pylabrobot/io/serial.py
@@ -1,11 +1,12 @@
-import asyncio
 import logging
-from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass
-from io import IOBase
 from typing import Optional, cast

+import anyio
+
+from pylabrobot.concurrency import AsyncExitStackWithShielding
 from pylabrobot.io.errors import ValidationError
+from pylabrobot.io.io import IOBase

 try:
   import serial
@@ -58,7 +59,7 @@ def __init__(
     self.parity = parity
     self.stopbits = stopbits
     self._ser: Optional[serial.Serial] = None
-    self._executor: Optional[ThreadPoolExecutor] = None
+    self._lock = anyio.Lock()
     self.write_timeout = write_timeout
     self.timeout = timeout
     self.rtscts = rtscts
@@ -76,7 +77,7 @@ def port(self) -> str:
     assert self._port is not None, "Port not set. Did you call setup()?"
     return self._port

-  async def setup(self):
+  async def _enter_lifespan(self, stack: AsyncExitStackWithShielding):
     """
     Initialize the serial connection to the device.
@@ -84,7 +85,7 @@ async def setup(self):
     provided path or by scanning for devices matching the configured USB
     VID:PID pair), validates that the detected/selected port corresponds to
     the expected hardware, and opens the serial connection in a dedicated
-    threadpool executor to avoid blocking the asyncio event loop.
+    thread to avoid blocking the AnyIO event loop.

     **Behavior:**
     - Ensures `pyserial` is installed; otherwise raises `RuntimeError`.
@@ -97,9 +98,9 @@ async def setup(self):
       - Verifies that it matches the specified VID/PID (when provided).
       - Logs the port choice for traceability.
     - Opens the serial port using the configured parameters
-      (baudrate, bytesize, parity, etc.) via `loop.run_in_executor` to
+      (baudrate, bytesize, parity, etc.) via `anyio.to_thread.run_sync` to
       ensure non-blocking operation.
-    - Cleans up the executor and re-raises the exception if the port cannot be opened.
+    - Registers a cleanup callback to close the serial port when the lifespan ends.

     **Raises:**
     RuntimeError:
@@ -110,9 +111,6 @@ async def setup(self):
       - If an explicitly provided port does not match the VID/PID.
     serial.SerialException:
       - If the serial connection fails to open (e.g., device already in use).
-
-    After successful completion, `self._ser` is an open `serial.Serial`
-    instance and `self._port` is updated to the resolved port path.
     """

     if not HAS_SERIAL:
@@ -121,9 +119,6 @@ async def setup(self):
         f"Import error: {_SERIAL_IMPORT_ERROR}"
       )

-    loop = asyncio.get_running_loop()
-    self._executor = ThreadPoolExecutor(max_workers=1)
-
     # 1. VID:PID specified - port maybe
     if self._vid is not None and self._pid is not None:
       matching_ports = [
@@ -174,43 +169,33 @@ def _open_serial() -> serial.Serial:
       )

     try:
-      self._ser = await loop.run_in_executor(self._executor, _open_serial)
+      async with self._lock:
+        self._ser = await anyio.to_thread.run_sync(_open_serial)
     except serial.SerialException as e:
       logger.error(
         f"Could not connect to device '{self._human_readable_device_name}', is it in use by a different notebook/process?"
       )
-      if self._executor is not None:
-        self._executor.shutdown(wait=True)
-        self._executor = None
       raise e

     assert self._ser is not None
-
     self._port = candidate_port

-  async def stop(self):
-    """Close the serial device."""
+    async def _cleanup():
+      if self._ser is not None and self._ser.is_open:
+        async with self._lock:
+          await anyio.to_thread.run_sync(self._ser.close)

-    if self._ser is not None and self._ser.is_open:
-      loop = asyncio.get_running_loop()
-
-      if self._executor is None:
-        raise RuntimeError(f"Call setup() first for device '{self._human_readable_device_name}'.")
-      await loop.run_in_executor(self._executor, self._ser.close)
-
-    if self._executor is not None:
-      self._executor.shutdown(wait=True)
-      self._executor = None
+    stack.push_shielded_async_callback(_cleanup)

   async def write(self, data: bytes):
     """Write data to the serial device."""
-    loop = asyncio.get_running_loop()
-    if self._executor is None or self._ser is None:
+    if self._ser is None:
       raise RuntimeError(f"Call setup() first for device '{self._human_readable_device_name}'.")
-    await loop.run_in_executor(self._executor, self._ser.write, data)
+    async with self._lock:
+      await anyio.to_thread.run_sync(self._ser.write, data)

     logger.log(LOG_LEVEL_IO, "[%s] write %s", self._port, data)
     capturer.record(
@@ -220,11 +205,11 @@ async def write(self, data: bytes):

   async def read(self, num_bytes: int = 1) -> bytes:
     """Read data from the serial device."""
-    loop = asyncio.get_running_loop()
-    if self._executor is None or self._ser is None:
+    if self._ser is None:
       raise RuntimeError(f"Call setup() first for device '{self._human_readable_device_name}'.")
-    data = await loop.run_in_executor(self._executor, self._ser.read, num_bytes)
+    async with self._lock:
+      data = await anyio.to_thread.run_sync(self._ser.read, num_bytes)

     if len(data) != 0:
       logger.log(LOG_LEVEL_IO, "[%s] read %s", self._port, data)
@@ -237,11 +222,11 @@ async def read(self, num_bytes: int = 1) -> bytes:
   async def readline(self) -> bytes:  # type: ignore  # very dumb it's reading from pyserial
     """Read a line from the serial device."""
-    loop = asyncio.get_running_loop()
-    if self._executor is None or self._ser is None:
+    if self._ser is None:
       raise RuntimeError(f"Call setup() first for device '{self._human_readable_device_name}'.")
-    data = await loop.run_in_executor(self._executor, self._ser.readline)
+    async with self._lock:
+      data = await anyio.to_thread.run_sync(self._ser.readline)

     if len(data) != 0:
       logger.log(LOG_LEVEL_IO, "[%s] readline %s", self._port, data)
@@ -254,32 +239,27 @@ async def readline(self) -> bytes:  # type: ignore  # very dumb it's reading from

   async def send_break(self, duration: float):
     """Send a break condition for the specified duration."""
-    loop = asyncio.get_running_loop()
-    if self._executor is None or self._ser is None:
+    if self._ser is None:
       raise RuntimeError(f"Call setup() first for device '{self._human_readable_device_name}'.")

-    def _send_break(ser, duration: float) -> None:
-      """Send a break condition for the specified duration."""
-      assert ser is not None, "forgot to call setup?"
-      ser.send_break(duration=duration)
-
-    await loop.run_in_executor(self._executor, lambda: _send_break(self._ser, duration=duration))
+    async with self._lock:
+      await anyio.to_thread.run_sync(self._ser.send_break, duration)
     logger.log(LOG_LEVEL_IO, "[%s] send_break %s", self._port, duration)
     capturer.record(SerialCommand(device_id=self.port, action="send_break", data=str(duration)))

   async def reset_input_buffer(self):
-    loop = asyncio.get_running_loop()
-    if self._executor is None or self._ser is None:
+    if self._ser is None:
       raise RuntimeError(f"Call setup() first for device '{self._human_readable_device_name}'.")
-    await loop.run_in_executor(self._executor, self._ser.reset_input_buffer)
+    async with self._lock:
+      await anyio.to_thread.run_sync(self._ser.reset_input_buffer)
     logger.log(LOG_LEVEL_IO, "[%s] reset_input_buffer", self._port)
     capturer.record(SerialCommand(device_id=self.port, action="reset_input_buffer", data=""))

   async def reset_output_buffer(self):
-    loop = asyncio.get_running_loop()
-    if self._executor is None or self._ser is None:
+    if self._ser is None:
       raise RuntimeError(f"Call setup() first for device '{self._human_readable_device_name}'.")
-    await loop.run_in_executor(self._executor, self._ser.reset_output_buffer)
+    async with self._lock:
+      await anyio.to_thread.run_sync(self._ser.reset_output_buffer)
     logger.log(LOG_LEVEL_IO, "[%s] reset_output_buffer", self._port)
     capturer.record(SerialCommand(device_id=self.port, action="reset_output_buffer", data=""))

@@ -374,7 +354,7 @@ def __init__(
     )
     self.cr = cr

-  async def setup(self):
+  async def _enter_lifespan(self, stack: AsyncExitStackWithShielding):
     pass

   async def write(self, data: bytes):
diff --git a/pylabrobot/io/sila/discovery.py b/pylabrobot/io/sila/discovery.py
index a65e9c2d42b..49625397e87 100644
--- a/pylabrobot/io/sila/discovery.py
+++ b/pylabrobot/io/sila/discovery.py
@@ -11,7 +11,6 @@

 from __future__ import annotations

-import asyncio
 import dataclasses
 import logging
 import os
@@ -23,6 +22,8 @@
 import xml.etree.ElementTree as ET
 from typing import TYPE_CHECKING, AsyncGenerator, Optional

+import anyio
+
 try:
   from zeroconf import ServiceBrowser, ServiceListener, Zeroconf

@@ -174,7 +175,6 @@ async def _netbios_scan(interface: str, timeout: float = 3.0) -> dict[str, str]:

   Returns a dict mapping IP -> NetBIOS name.
   """
-  loop = asyncio.get_running_loop()
   results: dict[str, str] = {}

   sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
@@ -182,39 +182,34 @@ async def _netbios_scan(interface: str, timeout: float = 3.0) -> dict[str, str]:
   sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
   sock.bind((interface, 0))
   # Use a short blocking timeout so recvfrom in the executor thread returns
-  # promptly rather than blocking forever, while still allowing the asyncio
-  # wait_for to enforce the overall deadline.
+  # promptly rather than blocking forever, while still allowing AnyIO
+  # move_on_after to enforce the overall deadline.
   sock.settimeout(0.5)

   # Link-local is always a /16 subnet (169.254.0.0/16), so broadcast to x.x.255.255.
   parts = interface.split(".")
   broadcast = f"{parts[0]}.{parts[1]}.255.255"

-  # Use run_in_executor for sendto/recvfrom since the async loop equivalents
-  # (loop.sock_sendto / loop.sock_recvfrom) require Python 3.11+.
   try:
-    await loop.run_in_executor(None, lambda: sock.sendto(_NBNS_WILDCARD_QUERY, (broadcast, 137)))
+    await anyio.to_thread.run_sync(lambda: sock.sendto(_NBNS_WILDCARD_QUERY, (broadcast, 137)))
   except OSError:
     logger.debug("NetBIOS broadcast failed on %s", interface)
     sock.close()
     return results

-  deadline = loop.time() + timeout
-  while loop.time() < deadline:
-    try:
-      data, (addr, _) = await asyncio.wait_for(
-        loop.run_in_executor(None, lambda: sock.recvfrom(65535)),
-        timeout=max(0.1, deadline - loop.time()),
-      )
-    except (asyncio.TimeoutError, socket.timeout, OSError):
-      continue
+  with anyio.move_on_after(timeout):
+    while True:
+      try:
+        data, (addr, _) = await anyio.to_thread.run_sync(lambda: sock.recvfrom(65535))
+      except (socket.timeout, OSError):
+        continue

-    if addr == interface:
-      continue
+      if addr == interface:
+        continue

-    name = _decode_nbns_name(data)
-    if name:
-      results[addr] = name
+      name = _decode_nbns_name(data)
+      if name:
+        results[addr] = name

   sock.close()
   return results
@@ -231,38 +226,35 @@ async def _ping_broadcast(interface: str) -> None:
   interface — without it the broadcast goes out on the default route. On
   Linux, ``ping -I `` serves the same purpose.
   """
-  loop = asyncio.get_running_loop()
   parts = interface.split(".")
   broadcast = f"{parts[0]}.{parts[1]}.255.255"

   if sys.platform == "win32":
     cmd = ["ping", "-n", "3", "-w", "1000", broadcast]
   elif sys.platform == "linux":
-    iface_name = await loop.run_in_executor(None, _interface_name_for_ip_sync, interface)
+    iface_name = await anyio.to_thread.run_sync(_interface_name_for_ip_sync, interface)
     if iface_name:
       cmd = ["ping", "-c", "3", "-W", "1", "-I", iface_name, broadcast]
     else:
       cmd = ["ping", "-c", "3", "-W", "1", broadcast]
   else:
     # macOS / BSD: -b binds to a named interface
-    iface_name = await loop.run_in_executor(None, _interface_name_for_ip_sync, interface)
+    iface_name = await anyio.to_thread.run_sync(_interface_name_for_ip_sync, interface)
     if iface_name:
       cmd = ["ping", "-c", "3", "-W", "1", "-b", iface_name, broadcast]
     else:
       cmd = ["ping", "-c", "3", "-W", "1", broadcast]

   try:
-    proc = await asyncio.create_subprocess_exec(
-      *cmd,
-      stdout=asyncio.subprocess.DEVNULL,
-      stderr=asyncio.subprocess.DEVNULL,
-    )
-    await asyncio.wait_for(proc.wait(), timeout=5)
-  except (FileNotFoundError, asyncio.TimeoutError):
+    with anyio.move_on_after(5):
+      await anyio.run_process(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
+  except subprocess.CalledProcessError:
+    pass
+  except FileNotFoundError:
     pass

   # Give devices a moment to respond so ARP entries are populated.
-  await asyncio.sleep(0.5)
+  await anyio.sleep(0.5)


 async def _arp_scan(interface: str) -> dict[str, str]:
@@ -294,21 +286,19 @@ async def _arp_scan_bsd(interface: str) -> dict[str, str]:
     ? (169.254.245.237) at 0:5:51:e:e5:7e on en13 [ethernet]
   """
   # Resolve our IP to an interface name (e.g. "en13") so we can filter ARP entries.
-  loop = asyncio.get_running_loop()
-  iface_name = await loop.run_in_executor(None, _interface_name_for_ip_sync, interface)
+  iface_name = await anyio.to_thread.run_sync(_interface_name_for_ip_sync, interface)
   if not iface_name:
     logger.debug("could not resolve interface name for %s, skipping ARP scan", interface)
     return {}

   try:
-    proc = await asyncio.create_subprocess_exec(
-      "arp",
-      "-an",
-      stdout=asyncio.subprocess.PIPE,
-      stderr=asyncio.subprocess.DEVNULL,
-    )
-    stdout, _ = await asyncio.wait_for(proc.communicate(), timeout=5)
-  except (FileNotFoundError, asyncio.TimeoutError):
+    with anyio.move_on_after(5) as cancel_scope:
+      result = await anyio.run_process(["arp", "-an"])
+      stdout = result.stdout
+
+    if cancel_scope.cancel_called:
+      return {}
+  except (FileNotFoundError, subprocess.CalledProcessError):
     return {}

   results: dict[str, str] = {}
@@ -342,19 +332,16 @@ async def _arp_scan_linux(interface: str) -> dict[str, str]:
     # Fall back to arp -an on non-procfs Linux systems.
     return await _arp_scan_bsd(interface)

-  loop = asyncio.get_running_loop()
-
-  try:
+  from anyio import Path

-    def _read():
-      with open("/proc/net/arp") as f:
-        return f.read()
-
-    text = await loop.run_in_executor(None, _read)
+  try:
+    path = Path("/proc/net/arp")
+    text = await path.read_text()
   except OSError:
     return {}

   # Determine the OS-level interface name for our IP so we can filter entries.
-  iface_name = await loop.run_in_executor(None, _interface_name_for_ip_sync, interface)
+  iface_name = await anyio.to_thread.run_sync(_interface_name_for_ip_sync, interface)
   if not iface_name:
     logger.debug("could not resolve interface name for %s, skipping ARP scan", interface)
     return {}
@@ -390,14 +377,13 @@ async def _arp_scan_windows(interface: str) -> dict[str, str]:

   Windows groups entries by interface, so we find the section matching our IP.
""" try: - proc = await asyncio.create_subprocess_exec( - "arp", - "-a", - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.DEVNULL, - ) - stdout, _ = await asyncio.wait_for(proc.communicate(), timeout=5) - except (FileNotFoundError, asyncio.TimeoutError): + with anyio.move_on_after(5) as cancel_scope: + result = await anyio.run_process(["arp", "-a"]) + stdout = result.stdout + + if cancel_scope.cancel_called: + return {} + except (FileNotFoundError, subprocess.CalledProcessError): return {} results: dict[str, str] = {} @@ -482,30 +468,18 @@ async def _get_device_identification( f"\r\n" ).encode() + body - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) try: - if interface: - sock.bind((interface, 0)) - sock.setblocking(False) - - loop = asyncio.get_running_loop() - deadline = loop.time() + timeout - - def _remaining() -> float: - return max(0.01, deadline - loop.time()) + with anyio.fail_after(timeout): + async with await anyio.connect_tcp(host, port, local_host=interface) as stream: + await stream.send(request) - await asyncio.wait_for(loop.sock_connect(sock, (host, port)), timeout=_remaining()) - await asyncio.wait_for(loop.sock_sendall(sock, request), timeout=_remaining()) - - resp = b"" - while True: - try: - chunk = await asyncio.wait_for(loop.sock_recv(sock, 4096), timeout=_remaining()) - if not chunk: - break - resp += chunk - except asyncio.TimeoutError: - break + resp = b"" + try: + while True: + chunk = await stream.receive() + resp += chunk + except anyio.EndOfStream: + pass # Extract XML from HTTP response text = resp.decode("utf-8", errors="replace") @@ -518,10 +492,8 @@ def _remaining() -> float: xml_text = xml_text[: -len(suffix)] break return _parse_device_identification(host, port, xml_text.encode("utf-8")) - except (OSError, asyncio.TimeoutError): + except (OSError, TimeoutError): pass - finally: - sock.close() return None @@ -539,42 +511,55 @@ async def _discover_sila1( logger.debug("no interface provided for SiLA 1 
discovery, skipping") return [] - loop = asyncio.get_running_loop() - deadline = loop.time() + timeout + devices: list[SiLADevice] = [] + hosts: dict[str, str] = {} - # Run host discovery methods in parallel. - # Cap NetBIOS at 3s — any device that responds will do so within a second or two. - netbios_task = asyncio.ensure_future(_netbios_scan(interface, timeout=min(timeout, 3.0))) - arp_task = asyncio.ensure_future(_arp_scan(interface)) - scan_results = await asyncio.gather(netbios_task, arp_task, return_exceptions=True) + with anyio.move_on_after(timeout): + # Run host discovery methods in parallel. + # Cap NetBIOS at 3s — any device that responds will do so within a second or two. + scan_results = {} + async with anyio.create_task_group() as tg: - hosts: dict[str, str] = {} - if isinstance(scan_results[0], dict): - hosts.update(scan_results[0]) - if isinstance(scan_results[1], dict): - for ip, name in scan_results[1].items(): + async def do_netbios(): + scan_results["netbios"] = await _netbios_scan(interface, timeout=min(timeout, 3.0)) + + async def do_arp(): + scan_results["arp"] = await _arp_scan(interface) + + tg.start_soon(do_netbios) + tg.start_soon(do_arp) + + hosts.update(scan_results.get("netbios", {})) + for ip, name in scan_results.get("arp", {}).items(): if ip not in hosts: logger.debug("found %s via ARP (not NetBIOS)", ip) hosts[ip] = name - if not hosts: - return [] + if not hosts: + return [] - remaining = max(0.01, deadline - loop.time()) - devices: list[SiLADevice] = [] - host_list = [ip for ip in hosts if not ip.endswith(".255")] - coros = [ - _get_device_identification(ip, port, interface=interface, timeout=remaining) for ip in host_list - ] - results = await asyncio.gather(*coros, return_exceptions=True) - for ip, r in zip(host_list, results): - if isinstance(r, SiLADevice): - devices.append(r) - else: - # Host is reachable but didn't respond to GetDeviceIdentification. - # Include it with whatever we know (name from NetBIOS, or just the IP). 
- name = hosts.get(ip, "") or ip - devices.append(SiLADevice(host=ip, port=port, name=name, sila_version=1)) + host_list = [ip for ip in hosts if not ip.endswith(".255")] + + identification_results = {} + async with anyio.create_task_group() as tg: + + async def do_query(ip): + identification_results[ip] = await _get_device_identification( + ip, port, interface=interface, timeout=timeout + ) + + for ip in host_list: + tg.start_soon(do_query, ip) + + for ip in host_list: + r = identification_results.get(ip) + if isinstance(r, SiLADevice): + devices.append(r) + else: + # Host is reachable but didn't respond to GetDeviceIdentification. + # Include it with whatever we know (name from NetBIOS, or just the IP). + name = hosts.get(ip, "") or ip + devices.append(SiLADevice(host=ip, port=port, name=name, sila_version=1)) return devices @@ -606,14 +591,14 @@ def remove_service(self, zc: Zeroconf, type_: str, name: str) -> None: def update_service(self, zc: Zeroconf, type_: str, name: str) -> None: pass - loop = asyncio.get_running_loop() - zc = await loop.run_in_executor(None, Zeroconf) + zc = await anyio.to_thread.run_sync(Zeroconf) try: listener = _Listener() - await loop.run_in_executor(None, lambda: ServiceBrowser(zc, SILA_MDNS_TYPE, listener)) - await asyncio.sleep(timeout) + await anyio.to_thread.run_sync(lambda: ServiceBrowser(zc, SILA_MDNS_TYPE, listener)) + await anyio.sleep(timeout) finally: - await loop.run_in_executor(None, zc.close) + with anyio.CancelScope(shield=True): + await anyio.to_thread.run_sync(zc.close) return devices @@ -647,33 +632,31 @@ async def discover_iter( if not interfaces: logger.debug("no link-local interfaces found, SiLA 1 discovery will be skipped") - tasks = [ - asyncio.ensure_future(c) - for c in [_discover_sila1(timeout=timeout, interface=iface) for iface in interfaces] - + [_discover_sila2(timeout)] - ] + send_stream, receive_stream = anyio.create_memory_object_stream(100) - seen: set[tuple[str, int]] = set() - pending = set(tasks) - 
try: - while pending: - done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED) - for t in done: - try: - result = t.result() - except Exception: - continue + async def worker(s_stream, func, *args): + async with s_stream: + try: + result = await func(*args) if isinstance(result, list): for d in result: - key = (d.host, d.port) - if key not in seen: - seen.add(key) - yield d - finally: - for t in pending: - t.cancel() - if pending: - await asyncio.gather(*pending, return_exceptions=True) + await s_stream.send(d) + except Exception: + pass + + seen: set[tuple[str, int]] = set() + + async with anyio.create_task_group() as tg: + async with send_stream: + for iface in interfaces: + tg.start_soon(worker, send_stream.clone(), _discover_sila1, timeout, iface) + tg.start_soon(worker, send_stream.clone(), _discover_sila2, timeout) + + async for d in receive_stream: + key = (d.host, d.port) + if key not in seen: + seen.add(key) + yield d async def discover( @@ -725,4 +708,4 @@ async def main(): if not found: print("No SiLA devices found.") - asyncio.run(main()) + anyio.run(main) diff --git a/pylabrobot/io/sila/discovery_tests.py b/pylabrobot/io/sila/discovery_tests.py index 5f96552b8bc..0b8086d0486 100644 --- a/pylabrobot/io/sila/discovery_tests.py +++ b/pylabrobot/io/sila/discovery_tests.py @@ -1,4 +1,3 @@ -import asyncio import socket import struct import unittest @@ -14,6 +13,7 @@ _discover_sila2, _parse_device_identification, ) +from pylabrobot.testing.concurrency import AnyioTestBase class TestSiLADevice(unittest.TestCase): @@ -135,7 +135,7 @@ def test_invalid_xml(self): self.assertIsNone(_parse_device_identification("10.0.0.1", 8080, b"not xml")) -class TestArpScanBsd(unittest.TestCase): +class TestArpScanBsd(AnyioTestBase): ARP_OUTPUT = ( "? (169.254.245.237) at 0:5:51:e:e5:7e on en13 [ethernet]\n" "? 
(192.168.0.1) at aa:bb:cc:dd:ee:ff on en0 ifscope [ethernet]\n" @@ -145,13 +145,13 @@ class TestArpScanBsd(unittest.TestCase): ) @patch("pylabrobot.io.sila.discovery._interface_name_for_ip_sync", return_value="en13") - @patch("asyncio.create_subprocess_exec", new_callable=AsyncMock) - def test_parses_link_local_entries(self, mock_exec, _mock_iface): - mock_proc = AsyncMock() - mock_proc.communicate.return_value = (self.ARP_OUTPUT.encode(), b"") - mock_exec.return_value = mock_proc + @patch("anyio.run_process", new_callable=AsyncMock) + async def test_parses_link_local_entries(self, mock_run_process, _mock_iface): + mock_result = MagicMock() + mock_result.stdout = self.ARP_OUTPUT.encode() + mock_run_process.return_value = mock_result - results = asyncio.run(_arp_scan_bsd("169.254.229.18")) + results = await _arp_scan_bsd("169.254.229.18") self.assertIn("169.254.245.237", results) self.assertIn("169.254.99.1", results) # Non-link-local should be excluded @@ -164,23 +164,23 @@ def test_parses_link_local_entries(self, mock_exec, _mock_iface): self.assertNotIn("169.254.50.50", results) @patch("pylabrobot.io.sila.discovery._interface_name_for_ip_sync", return_value="en13") - @patch("asyncio.create_subprocess_exec", new_callable=AsyncMock) - def test_empty_output(self, mock_exec, _mock_iface): - mock_proc = AsyncMock() - mock_proc.communicate.return_value = (b"", b"") - mock_exec.return_value = mock_proc + @patch("anyio.run_process", new_callable=AsyncMock) + async def test_empty_output(self, mock_run_process, _mock_iface): + mock_result = MagicMock() + mock_result.stdout = b"" + mock_run_process.return_value = mock_result - results = asyncio.run(_arp_scan_bsd("169.254.229.18")) + results = await _arp_scan_bsd("169.254.229.18") self.assertEqual(results, {}) @patch("pylabrobot.io.sila.discovery._interface_name_for_ip_sync", return_value=None) - def test_returns_empty_when_interface_unknown(self, _mock_iface): + async def test_returns_empty_when_interface_unknown(self, 
_mock_iface): """If we can't resolve the interface name, return empty rather than all entries.""" - results = asyncio.run(_arp_scan_bsd("169.254.229.18")) + results = await _arp_scan_bsd("169.254.229.18") self.assertEqual(results, {}) -class TestArpScanLinux(unittest.TestCase): +class TestArpScanLinux(AnyioTestBase): PROC_NET_ARP = ( "IP address HW type Flags HW address Mask Device\n" "169.254.245.237 0x1 0x2 00:05:51:0e:e5:7e * eth0\n" @@ -190,18 +190,10 @@ class TestArpScanLinux(unittest.TestCase): @patch("pylabrobot.io.sila.discovery._interface_name_for_ip_sync", return_value="eth0") @patch("os.path.exists", return_value=True) - def test_parses_proc_net_arp(self, _mock_exists, _mock_iface): - with patch( - "builtins.open", - MagicMock( - return_value=MagicMock( - __enter__=lambda s: s, - __exit__=MagicMock(return_value=False), - read=MagicMock(return_value=self.PROC_NET_ARP), - ) - ), - ): - results = asyncio.run(_arp_scan_linux("169.254.229.18")) + async def test_parses_proc_net_arp(self, _mock_exists, _mock_iface): + with patch("anyio.Path.read_text", new_callable=AsyncMock) as mock_read_text: + mock_read_text.return_value = self.PROC_NET_ARP + results = await _arp_scan_linux("169.254.229.18") self.assertIn("169.254.245.237", results) # Non-link-local excluded @@ -210,7 +202,7 @@ def test_parses_proc_net_arp(self, _mock_exists, _mock_iface): self.assertNotIn("169.254.10.20", results) -class TestArpScanWindows(unittest.TestCase): +class TestArpScanWindows(AnyioTestBase): ARP_OUTPUT = ( "\r\n" "Interface: 169.254.229.18 --- 0x5\r\n" @@ -224,13 +216,13 @@ class TestArpScanWindows(unittest.TestCase): " 169.254.99.1 11-22-33-44-55-66 dynamic\r\n" ) - @patch("asyncio.create_subprocess_exec", new_callable=AsyncMock) - def test_parses_correct_interface_section(self, mock_exec): - mock_proc = AsyncMock() - mock_proc.communicate.return_value = (self.ARP_OUTPUT.encode(), b"") - mock_exec.return_value = mock_proc + @patch("anyio.run_process", new_callable=AsyncMock) + 
async def test_parses_correct_interface_section(self, mock_run_process): + mock_result = MagicMock() + mock_result.stdout = self.ARP_OUTPUT.encode() + mock_run_process.return_value = mock_result - results = asyncio.run(_arp_scan_windows("169.254.229.18")) + results = await _arp_scan_windows("169.254.229.18") self.assertIn("169.254.245.237", results) self.assertIn("169.254.10.20", results) # This is under a different interface section @@ -240,16 +232,16 @@ def test_parses_correct_interface_section(self, mock_exec): self.assertNotIn("169.254.229.18", results) -class TestDiscoverSila2(unittest.TestCase): +class TestDiscoverSila2(AnyioTestBase): @patch("pylabrobot.io.sila.discovery.HAS_ZEROCONF", False) - def test_no_zeroconf_returns_empty(self): - devices = asyncio.run(_discover_sila2(timeout=0.1)) + async def test_no_zeroconf_returns_empty(self): + devices = await _discover_sila2(timeout=0.1) self.assertEqual(devices, []) @unittest.skipIf(not HAS_ZEROCONF, "zeroconf not installed") @patch("pylabrobot.io.sila.discovery.Zeroconf", create=True) @patch("pylabrobot.io.sila.discovery.ServiceBrowser", create=True) - def test_discovers_device(self, mock_browser_cls, mock_zc_cls): + async def test_discovers_device(self, mock_browser_cls, mock_zc_cls): mock_zc = MagicMock() mock_zc_cls.return_value = mock_zc @@ -264,13 +256,9 @@ def side_effect(zc, type_, listener): mock_browser_cls.side_effect = side_effect - devices = asyncio.run(_discover_sila2(timeout=0.1)) + devices = await _discover_sila2(timeout=0.1) self.assertEqual(len(devices), 1) self.assertEqual(devices[0].host, "192.168.1.42") self.assertEqual(devices[0].port, 8091) self.assertEqual(devices[0].name, "Pico.local.") self.assertEqual(devices[0].sila_version, 2) - - -if __name__ == "__main__": - unittest.main() diff --git a/pylabrobot/io/sila/grpc_tests.py b/pylabrobot/io/sila/grpc_tests.py index bdb8bdcfc3f..9ee2e48006e 100644 --- a/pylabrobot/io/sila/grpc_tests.py +++ b/pylabrobot/io/sila/grpc_tests.py @@ -355,7 
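The converted tests patch `anyio.run_process` with `AsyncMock` and `await` the code under test directly instead of wrapping it in `asyncio.run`. A standalone sketch of that mocking pattern (module paths, names, and data are illustrative):

```python
from unittest.mock import AsyncMock, patch

import anyio


async def scan() -> bytes:
  # Stands in for code under test that shells out via anyio.run_process.
  result = await anyio.run_process(["arp", "-an"], check=False)
  return result.stdout


async def main() -> bytes:
  with patch("anyio.run_process", new_callable=AsyncMock) as mock_run:
    # An awaited AsyncMock returns its return_value; configure its .stdout.
    mock_run.return_value.stdout = b"? (169.254.0.1) at aa:bb:cc:dd:ee:ff on en0"
    return await scan()
```

Because `scan` looks up `anyio.run_process` at call time, patching the attribute on the `anyio` module is enough; no subprocess is spawned.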
+355,3 @@ def test_no_details_method(self): error = MagicMock(spec=[]) result = decode_grpc_error(error) self.assertIsInstance(result, str) - - -if __name__ == "__main__": - unittest.main() diff --git a/pylabrobot/io/socket.py b/pylabrobot/io/socket.py index 0fcae09a6bd..76b763e1caa 100644 --- a/pylabrobot/io/socket.py +++ b/pylabrobot/io/socket.py @@ -1,9 +1,13 @@ -import asyncio import logging import ssl from dataclasses import dataclass -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Any, Optional +import anyio +import anyio.streams.buffered +import anyio.streams.tls + +from pylabrobot.concurrency import AsyncExitStackWithShielding from pylabrobot.io.capture import Command, capturer, get_capture_or_validation_active from pylabrobot.io.errors import ValidationError from pylabrobot.io.io import IOBase @@ -41,52 +45,55 @@ def __init__( self._human_readable_device_name = human_readable_device_name self._host = host self._port = port - self._reader: Optional[asyncio.StreamReader] = None - self._writer: Optional[asyncio.StreamWriter] = None + self._stream: Optional[anyio.streams.buffered.BufferedByteStream] = None self._read_timeout = read_timeout self._write_timeout = write_timeout self._ssl_context = ssl_context self._server_hostname = server_hostname self._unique_id = f"{self._host}:{self._port}" - self._read_lock = asyncio.Lock() - self._write_lock = asyncio.Lock() + self._read_lock = anyio.Lock() + self._write_lock = anyio.Lock() self._ssl = ssl if get_capture_or_validation_active(): raise RuntimeError("Cannot create a new Socket object while capture or validation is active") - async def setup(self): + async def _enter_lifespan(self, stack: AsyncExitStackWithShielding): + await super()._enter_lifespan(stack) await self._connect() + stack.push_async_callback(self._disconnect) async def _connect(self): - self._reader, self._writer = await asyncio.open_connection( - host=self._host, - port=self._port, - ssl=self._ssl_context, - 
server_hostname=self._server_hostname, - ) - - async def stop(self): - await self._disconnect() + raw_stream = await anyio.connect_tcp(self._host, self._port) + stream: Any + if self._ssl_context: + stream = await anyio.streams.tls.TLSStream.wrap( + raw_stream, + ssl_context=self._ssl_context, + server_hostname=self._server_hostname, + ) # type: ignore[call-arg] + else: + stream = raw_stream + self._stream = anyio.streams.buffered.BufferedByteStream(stream) async def _disconnect(self): async with self._read_lock, self._write_lock: - self._reader = None - if self._writer is None: + if self._stream is None: return logger.info("Closing connection to socket %s:%s", self._host, self._port) try: - self._writer.close() - await self._writer.wait_closed() + await self._stream.aclose() except OSError as e: logger.warning("Error while closing socket connection: %s", e) finally: - self._writer = None + self._stream = None - async def reconnect(self): + async def reconnect(self, *, wait_time: float = 0): await self._disconnect() + if wait_time > 0: + await anyio.sleep(wait_time) await self._connect() def serialize(self): @@ -114,16 +121,15 @@ def deserialize(cls, data: dict) -> "Socket": ) async def write(self, data: bytes, timeout: Optional[float] = None) -> None: - """Wrapper around StreamWriter.write with lock and io logging. + """Wrapper around anyio.abc.ByteStream.send with lock and io logging. Does not retry on timeouts. 
""" - if self._writer is None: + if self._stream is None: raise RuntimeError( f"Socket for '{self._human_readable_device_name}' not set up; call setup() first" ) timeout = self._write_timeout if timeout is None else timeout async with self._write_lock: - self._writer.write(data) logger.log(LOG_LEVEL_IO, "[%s:%d] write %s", self._host, self._port, data) capturer.record( SocketCommand( @@ -133,9 +139,9 @@ async def write(self, data: bytes, timeout: Optional[float] = None) -> None: ) ) try: - await asyncio.wait_for(self._writer.drain(), timeout=timeout) - return - except asyncio.TimeoutError as exc: + with anyio.fail_after(timeout): + await self._stream.send(data) + except TimeoutError as exc: logger.error("write timeout: %r", exc) raise TimeoutError(f"Timeout while writing to socket after {timeout} seconds") from exc except (ConnectionResetError, OSError) as e: @@ -143,7 +149,7 @@ async def write(self, data: bytes, timeout: Optional[float] = None) -> None: raise async def read(self, num_bytes: int = 128, timeout: Optional[float] = None) -> bytes: - """Wrapper around StreamReader.read with lock and io logging. + """Wrapper around anyio.abc.ByteStream.receive with lock and io logging. Args: num_bytes: The maximum number of bytes to read from the socket. @@ -154,17 +160,21 @@ async def read(self, num_bytes: int = 128, timeout: Optional[float] = None) -> b Returns: The data read from the socket, which may be fewer than `num_bytes` bytes. 
""" - if self._reader is None: + if self._stream is None: raise RuntimeError( f"Socket for '{self._human_readable_device_name}' not set up; call setup() first" ) timeout = self._read_timeout if timeout is None else timeout async with self._read_lock: try: - data = await asyncio.wait_for(self._reader.read(num_bytes), timeout=timeout) - except asyncio.TimeoutError as exc: + with anyio.fail_after(timeout): + data = await self._stream.receive(num_bytes) + except TimeoutError as exc: logger.error("read timeout: %r", exc) raise TimeoutError(f"Timeout while reading from socket after {timeout} seconds") from exc + except anyio.EndOfStream: + data = b"" + logger.log(LOG_LEVEL_IO, "[%s:%d] read %s", self._host, self._port, data) capturer.record( SocketCommand( @@ -176,51 +186,71 @@ async def read(self, num_bytes: int = 128, timeout: Optional[float] = None) -> b return data async def readline(self, timeout: Optional[float] = None) -> bytes: - """Wrapper around StreamReader.readline with lock and io logging.""" - if self._reader is None: + """Wrapper around reading from stream until newline with lock and io logging.""" + if self._stream is None: raise RuntimeError( f"Socket for '{self._human_readable_device_name}' not set up; call setup() first" ) timeout = self._read_timeout if timeout is None else timeout async with self._read_lock: try: - data = await asyncio.wait_for(self._reader.readline(), timeout=timeout) - except asyncio.TimeoutError as exc: + with anyio.fail_after(timeout): + data = await self._stream.receive_until(b"\n", max_bytes=65536) + result = data + b"\n" + except TimeoutError as exc: logger.error("readline timeout: %r", exc) raise TimeoutError(f"Timeout while reading from socket after {timeout} seconds") from exc - logger.log(LOG_LEVEL_IO, "[%s:%d] readline %s", self._host, self._port, data) + except anyio.IncompleteRead: + logger.warning("readline: connection closed before newline found, returning partial data") + result = await 
self._stream.receive(len(self._stream.buffer)) + except anyio.streams.buffered.DelimiterNotFound as exc: + logger.error("readline error: delimiter not found") + raise RuntimeError("Newline not found within max_bytes") from exc + + logger.log(LOG_LEVEL_IO, "[%s:%d] readline %s", self._host, self._port, result) capturer.record( SocketCommand( device_id=self._unique_id, action="readline", - data=data.hex(), + data=result.hex(), ) ) - return data + return result async def readuntil(self, separator: bytes = b"\n", timeout: Optional[float] = None) -> bytes: - """Wrapper around StreamReader.readuntil with lock and io logging. + """Wrapper around reading from stream until separator with lock and io logging. Do not retry on timeouts.""" - if self._reader is None: + if self._stream is None: raise RuntimeError( f"Socket for '{self._human_readable_device_name}' not set up; call setup() first" ) timeout = self._read_timeout if timeout is None else timeout async with self._read_lock: try: - data = await asyncio.wait_for(self._reader.readuntil(separator), timeout=timeout) - except asyncio.TimeoutError as exc: + with anyio.fail_after(timeout): + data = await self._stream.receive_until(separator, max_bytes=65536) + result = data + separator + except TimeoutError as exc: logger.error("readuntil timeout: %r", exc) raise TimeoutError(f"Timeout while reading from socket after {timeout} seconds") from exc - logger.log(LOG_LEVEL_IO, "[%s:%d] readuntil %s", self._host, self._port, data) + except anyio.IncompleteRead: + logger.warning( + "readuntil: connection closed before separator found, returning partial data" + ) + result = await self._stream.receive(len(self._stream.buffer)) + except anyio.streams.buffered.DelimiterNotFound as exc: + logger.error("readuntil error: delimiter not found") + raise RuntimeError("Separator not found within max_bytes") from exc + + logger.log(LOG_LEVEL_IO, "[%s:%d] readuntil %s", self._host, self._port, result) capturer.record( SocketCommand( 
device_id=self._unique_id, action="readuntil:" + separator.hex(), - data=data.hex(), + data=result.hex(), ) ) - return data + return result async def read_exact(self, num_bytes: int, timeout: Optional[float] = None) -> bytes: """Read exactly num_bytes, blocking until all bytes are received. @@ -239,34 +269,31 @@ async def read_exact(self, num_bytes: int, timeout: Optional[float] = None) -> b ConnectionError: If the connection is closed before num_bytes are read. TimeoutError: If timeout is reached before num_bytes are read. """ - if self._reader is None: + if self._stream is None: raise RuntimeError( f"Socket for '{self._human_readable_device_name}' not set up; call setup() first" ) timeout = self._read_timeout if timeout is None else timeout - data = bytearray() async with self._read_lock: - while len(data) < num_bytes: - remaining = num_bytes - len(data) - try: - chunk = await asyncio.wait_for(self._reader.read(remaining), timeout=timeout) - except asyncio.TimeoutError as exc: - logger.error("read_exact timeout: %r", exc) - raise TimeoutError(f"Timeout while reading from socket after {timeout} seconds") from exc - if len(chunk) == 0: - raise ConnectionError("Connection closed before num_bytes are read") - data.extend(chunk) + try: + with anyio.fail_after(timeout): + data = await self._stream.receive_exactly(num_bytes) + except TimeoutError as exc: + logger.error("read_exact timeout: %r", exc) + raise TimeoutError(f"Timeout while reading from socket after {timeout} seconds") from exc + except anyio.IncompleteRead as exc: + logger.error("read_exact error: %r", exc) + raise ConnectionError("Connection closed before num_bytes were read") from exc - result = bytes(data) - logger.log(LOG_LEVEL_IO, "[%s:%d] read_exact %s", self._host, self._port, result.hex()) + logger.log(LOG_LEVEL_IO, "[%s:%d] read_exact %s", self._host, self._port, data.hex()) capturer.record( SocketCommand( device_id=self._unique_id, action="read_exact", - data=result.hex(), + data=data.hex(), ) ) - 
return result + return data async def read_until_eof(self, chunk_size: int = 1024, timeout: Optional[float] = None) -> bytes: """Read until EOF is reached. @@ -277,34 +304,35 @@ async def read_until_eof(self, chunk_size: int = 1024, timeout: Optional[float] async with self._read_lock: while True: - if self._reader is None: + if self._stream is None: raise RuntimeError( f"Socket for '{self._human_readable_device_name}' not set up; call setup() first" ) try: - chunk = await asyncio.wait_for(self._reader.read(chunk_size), timeout=timeout) - except asyncio.TimeoutError as exc: + with anyio.fail_after(timeout): + chunk = await self._stream.receive(chunk_size) + except TimeoutError as exc: # if some previous read attempts already return some data, we should consider this a success if len(buf) > 0: break logger.error("read_until_eof timeout: %r", exc) raise TimeoutError(f"Timeout while reading from socket after {timeout} seconds") from exc - if len(chunk) == 0: + except anyio.EndOfStream: break logger.debug("read_until_eof: got %d bytes", len(chunk)) buf.extend(chunk) - line = bytes(buf) - logger.log(LOG_LEVEL_IO, "[%s:%d] read_until_eof %s", self._host, self._port, line) + result = bytes(buf) + logger.log(LOG_LEVEL_IO, "[%s:%d] read_until_eof %s", self._host, self._port, result) capturer.record( SocketCommand( device_id=self._unique_id, action="read_until_eof", - data=line.hex(), + data=result.hex(), ) ) - return line + return result class SocketValidator(Socket): diff --git a/pylabrobot/io/usb.py b/pylabrobot/io/usb.py index 14d2802f986..62cb0442f35 100644 --- a/pylabrobot/io/usb.py +++ b/pylabrobot/io/usb.py @@ -1,10 +1,10 @@ -import asyncio +import contextlib import logging -import time -from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass from typing import TYPE_CHECKING, Callable, List, Optional +import anyio + from pylabrobot.io.capture import Command, capturer, get_capture_or_validation_active from pylabrobot.io.errors import 
ValidationError from pylabrobot.io.io import IOBase @@ -102,8 +102,6 @@ def __init__( self.read_endpoint: Optional[usb.core.Endpoint] = None self.write_endpoint: Optional[usb.core.Endpoint] = None - self._executor: Optional[ThreadPoolExecutor] = None - # unique id in the logs self._unique_id = f"[{hex(self._id_vendor)}:{hex(self._id_product)}][{self._serial_number or ''}][{self._device_address or ''}]" self._human_readable_device_name = human_readable_device_name @@ -117,30 +115,28 @@ async def write(self, data: bytes, timeout: Optional[float] = None): (specified by the `write_timeout` attribute). """ - if self.dev is None or self.read_endpoint is None: + dev = self.dev + write_endpoint = self.write_endpoint + if dev is None or self.read_endpoint is None or write_endpoint is None: raise RuntimeError(f"USB device for '{self._human_readable_device_name}' is not connected.") if timeout is None: timeout = self.write_timeout # write command to endpoint - loop = asyncio.get_running_loop() - write_endpoint = self.write_endpoint - dev = self.dev - if self._executor is None or dev is None or write_endpoint is None: - raise RuntimeError(f"Call setup() first for USB device '{self._human_readable_device_name}'.") - await loop.run_in_executor( - self._executor, - lambda: dev.write( - write_endpoint, data, timeout=int(timeout * 1000) - ), # PyUSB expects timeout in milliseconds - ) - if len(data) % write_endpoint.wMaxPacketSize == 0: - # send a zero-length packet to indicate the end of the transfer - await loop.run_in_executor( - self._executor, - lambda: dev.write(write_endpoint, b"", timeout=int(timeout * 1000)), - ) + async def write(d): + t = anyio.current_effective_deadline() - anyio.current_time() + assert t < float("inf"), "Timeout must be set" + timeout_ms = int(t * 1000) + await anyio.to_thread.run_sync(lambda: dev.write(write_endpoint, d, timeout=timeout_ms)) + + with contextlib.ExitStack() as stack: + if timeout is not None: + 
stack.enter_context(anyio.fail_after(timeout)) + await write(data) + if len(data) % write_endpoint.wMaxPacketSize == 0: + # send a zero-length packet to indicate the end of the transfer + await write(b"") logger.log(LOG_LEVEL_IO, "%s write: %s", self._unique_id, data) capturer.record( USBCommand( @@ -150,7 +146,7 @@ async def write(self, data: bytes, timeout: Optional[float] = None): ) ) - def _read_packet( + async def _read_packet( self, size: Optional[int] = None, timeout: Optional[float] = None, @@ -167,8 +163,8 @@ def _read_packet( Returns: A bytearray containing the data read, or None if no data was received. """ - - if self.dev is None or self.read_endpoint is None: + dev = self.dev + if dev is None or self.read_endpoint is None: raise RuntimeError(f"USB device for '{self._human_readable_device_name}' is not connected.") ep = endpoint if endpoint is not None else self.read_endpoint @@ -179,7 +175,7 @@ def _read_packet( if size is None: if isinstance(ep, int): # Find endpoint object to get max packet size - cfg = self.dev.get_active_configuration() + cfg = dev.get_active_configuration() intf = cfg[(0, 0)] ep_obj = usb.util.find_descriptor( intf, @@ -197,16 +193,20 @@ def _read_packet( timeout = self.packet_read_timeout try: - res = self.dev.read( - ep, - read_size, - timeout=int(timeout * 1000), # timeout in ms - ) + with anyio.fail_after(timeout): + res = await anyio.to_thread.run_sync( + lambda: dev.read( + ep, + read_size, + timeout=int(timeout * 1000), # timeout in ms + ), + abandon_on_cancel=True, + ) if res is not None: return bytearray(res) return None - except usb.core.USBError: + except (usb.core.USBError, TimeoutError): # No data available (yet), this will give a timeout error. Don't reraise. 
return None @@ -226,48 +226,42 @@ async def read(self, timeout: Optional[int] = None, size: Optional[int] = None) if timeout is None: timeout = self.read_timeout - def read_or_timeout(): - # Attempt to read packets until timeout, or when we identify the right id. - timeout_time = time.time() + timeout - - while time.time() < timeout_time: - # read response from endpoint, and keep reading until the packet is smaller than the max - # packet size: if the packet is that size, it means that there may be more data to read. - resp = bytearray() - last_packet: Optional[bytearray] = None - while True: # read while we have data, and while the last packet is the max size. - remaining = size - len(resp) if size is not None else None - last_packet = self._read_packet(size=remaining) - if last_packet is not None: - resp += last_packet - if self.read_endpoint is None: - raise RuntimeError("Read endpoint is None. Call setup() first.") - if last_packet is None or len(last_packet) != self.read_endpoint.wMaxPacketSize: - break - if size is not None and len(resp) >= size: - break - - if len(resp) == 0: - continue - - logger.log(LOG_LEVEL_IO, "%s read: %s", self._unique_id, resp) - capturer.record( - USBCommand( - device_id=self._unique_id, - action="read", - data=resp.decode("unicode_escape", errors="backslashreplace"), + try: + with anyio.fail_after(timeout): + while True: + # read response from endpoint, and keep reading until the packet is smaller than the max + # packet size: if the packet is that size, it means that there may be more data to read. + resp = bytearray() + last_packet: Optional[bytearray] = None + while True: # read while we have data, and while the last packet is the max size. + remaining = size - len(resp) if size is not None else None + last_packet = await self._read_packet(size=remaining) + if last_packet is not None: + resp += last_packet + if self.read_endpoint is None: + raise RuntimeError("Read endpoint is None. 
Call setup() first.") + if last_packet is None or len(last_packet) != self.read_endpoint.wMaxPacketSize: + break + if size is not None and len(resp) >= size: + break + + if len(resp) == 0: + continue + + logger.log(LOG_LEVEL_IO, "%s read: %s", self._unique_id, resp) + capturer.record( + USBCommand( + device_id=self._unique_id, + action="read", + data=resp.decode("unicode_escape", errors="backslashreplace"), + ) ) - ) - return resp - + return resp + except TimeoutError: + # Translate TimeoutError to a more specific error message. raise TimeoutError( f"Timeout while reading from USB device '{self._human_readable_device_name}'." - ) - - loop = asyncio.get_running_loop() - if self._executor is None or self.dev is None: - raise RuntimeError(f"Call setup() first for USB device '{self._human_readable_device_name}'.") - return await loop.run_in_executor(self._executor, read_or_timeout) + ) from None def get_available_devices(self) -> List["usb.core.Device"]: """Get a list of available devices that match the specified vendor and product IDs, and serial @@ -373,15 +367,8 @@ def ctrl_transfer( return bytearray(res) - async def setup(self, empty_buffer=True): + async def _enter_lifespan(self, stack: contextlib.AsyncExitStack, *, empty_buffer=True): """Initialize the USB connection to the machine.""" - - if self.dev is not None: - # previous setup did not properly finish, - # or we are re-initializing the device. - logger.warning("USB device already connected. Closing previous connection.") - await self.stop() - if not USE_USB: raise RuntimeError( "pyusb/libusb is not installed. Install with: pip install pylabrobot[usb]. " @@ -398,6 +385,13 @@ async def setup(self, empty_buffer=True): logger.warning("Multiple devices found. Using the first one.") self.dev = devices[0] + # at this point, we manage `self.dev`; make sure it gets cleaned up again. 
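The hunk above registers USB teardown with `@stack.callback` the moment `self.dev` is assigned, so the device is disposed even if a later setup step raises. A minimal sketch of that acquire-then-register pattern (the `FakeDevice` and `enter_lifespan` names here are illustrative, not PyLabRobot's):

```python
import asyncio
import contextlib

disposed = []


class FakeDevice:
    """Stand-in for a pyusb device handle."""

    def dispose(self):
        disposed.append(self)


async def enter_lifespan(stack: contextlib.AsyncExitStack, fail_later: bool = False):
    dev = FakeDevice()
    # Register cleanup as soon as we own the resource, so any exception
    # raised by the remaining setup steps still disposes of it.
    stack.callback(dev.dispose)
    if fail_later:
        raise RuntimeError("a later configuration step failed")
    return dev


async def main():
    disposed.clear()

    # happy path: cleanup runs when the stack unwinds
    async with contextlib.AsyncExitStack() as stack:
        await enter_lifespan(stack)
    ok_normal = len(disposed) == 1

    # failure path: cleanup still runs, because it was registered
    # before the failing step
    try:
        async with contextlib.AsyncExitStack() as stack:
            await enter_lifespan(stack, fail_later=True)
    except RuntimeError:
        pass
    ok_failure = len(disposed) == 2

    return ok_normal and ok_failure


print(asyncio.run(main()))  # -> True
```

This is why the diff has no `_exit_lifespan`: cleanup is paired with acquisition at the point of acquisition, not collected in a separate teardown method that must mirror setup by hand.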
+ @stack.callback + def cleanup(): + logger.warning("Closing connection to USB device.") + usb.util.dispose_resources(self.dev) + self.dev = None + logger.info("Found USB device.") # set the active configuration. With no arguments, the first @@ -444,23 +438,14 @@ async def setup(self, empty_buffer=True): # Empty the read buffer. if empty_buffer: - while self._read_packet() is not None: + while await self._read_packet() is not None: pass - self._executor = ThreadPoolExecutor(max_workers=self.max_workers) - - async def stop(self): - """Close the USB connection to the machine.""" - - if self.dev is None: - raise ValueError("USB device was not connected.") - logger.warning("Closing connection to USB device.") - usb.util.dispose_resources(self.dev) - self.dev = None - - if self._executor is not None: - self._executor.shutdown(wait=True) - self._executor = None + async def recover_transport(self): + """Try to recover from a broken transport.""" + # TODO: dispose of `.dev` and re-configure + while await self._read_packet() is not None: + pass def serialize(self) -> dict: """Serialize the backend to a dictionary.""" @@ -516,7 +501,7 @@ def __init__( ) self.cr = cr - async def setup(self, empty_buffer=True): + async def _enter_lifespan(self, stack: contextlib.AsyncExitStack, *, empty_buffer=True): pass async def write(self, data: bytes, timeout: Optional[float] = None): diff --git a/pylabrobot/liquid_handling/backends/backend.py b/pylabrobot/liquid_handling/backends/backend.py index a16034577ea..45a33adf6e9 100644 --- a/pylabrobot/liquid_handling/backends/backend.py +++ b/pylabrobot/liquid_handling/backends/backend.py @@ -3,6 +3,7 @@ from abc import ABCMeta, abstractmethod from typing import Dict, List, Optional, Union +from pylabrobot.concurrency import AsyncExitStackWithShielding from pylabrobot.liquid_handling.channel_positioning import GENERIC_LH_MIN_SPACING_BETWEEN_CHANNELS from pylabrobot.liquid_handling.standard import ( Drop, @@ -78,9 +79,9 @@ def head(self) -> 
Dict[int, TipTracker]: def head96(self) -> Optional[Dict[int, TipTracker]]: return self._head96 - async def setup(self): - """Set up the robot. This method should be called before any other method is called.""" + async def _enter_lifespan(self, stack: AsyncExitStackWithShielding): assert self._deck is not None, "Deck not set" + await super()._enter_lifespan(stack) @property @abstractmethod diff --git a/pylabrobot/liquid_handling/backends/chatterbox.py b/pylabrobot/liquid_handling/backends/chatterbox.py index 227803bc860..be3f3736eac 100644 --- a/pylabrobot/liquid_handling/backends/chatterbox.py +++ b/pylabrobot/liquid_handling/backends/chatterbox.py @@ -1,5 +1,6 @@ from typing import List, Optional, Union +from pylabrobot.concurrency import AsyncExitStackWithShielding from pylabrobot.liquid_handling.backends.backend import ( LiquidHandlerBackend, ) @@ -46,12 +47,10 @@ def __init__(self, num_channels: int = 8): self._num_arms = 1 self._head96_installed = True - async def setup(self): - await super().setup() + async def _enter_lifespan(self, stack: AsyncExitStackWithShielding): + await super()._enter_lifespan(stack) print("Setting up the liquid handler.") - - async def stop(self): - print("Stopping the liquid handler.") + stack.callback(lambda: print("Stopping the liquid handler.")) def serialize(self) -> dict: return {**super().serialize(), "num_channels": self.num_channels} diff --git a/pylabrobot/liquid_handling/backends/chatterbox_tests.py b/pylabrobot/liquid_handling/backends/chatterbox_tests.py index 02e8642d51b..ab9e08189ef 100644 --- a/pylabrobot/liquid_handling/backends/chatterbox_tests.py +++ b/pylabrobot/liquid_handling/backends/chatterbox_tests.py @@ -1,4 +1,4 @@ -import unittest +import contextlib from pylabrobot.liquid_handling import LiquidHandler from pylabrobot.liquid_handling.backends.chatterbox import ( @@ -10,12 +10,13 @@ hamilton_96_tiprack_1000uL_filter, ) from pylabrobot.resources.hamilton import STARLetDeck +from pylabrobot.testing.concurrency 
import AnyioTestBase -class ChatterboxBackendTests(unittest.IsolatedAsyncioTestCase): +class ChatterboxBackendTests(AnyioTestBase): """Tests for chatterbox backend""" - def setUp(self) -> None: + async def _enter_lifespan(self, stack: contextlib.AsyncExitStack) -> None: self.deck = STARLetDeck() self.backend = LiquidHandlerChatterboxBackend(num_channels=8) self.lh = LiquidHandler(self.backend, deck=self.deck) @@ -23,14 +24,7 @@ def setUp(self) -> None: self.deck.assign_child_resource(self.tip_rack, rails=3) self.plate = Cor_96_wellplate_360ul_Fb(name="plate") self.deck.assign_child_resource(self.plate, rails=9) - - async def asyncSetUp(self) -> None: - await super().asyncSetUp() - await self.lh.setup() - - async def asyncTearDown(self) -> None: - await self.lh.stop() - await super().asyncTearDown() + await stack.enter_async_context(self.lh) async def test_pick_up_tips(self): await self.lh.pick_up_tips(self.tip_rack["A1"]) diff --git a/pylabrobot/liquid_handling/backends/hamilton/STAR_backend.py b/pylabrobot/liquid_handling/backends/hamilton/STAR_backend.py index 10cfc65de36..f1723e098d7 100644 --- a/pylabrobot/liquid_handling/backends/hamilton/STAR_backend.py +++ b/pylabrobot/liquid_handling/backends/hamilton/STAR_backend.py @@ -1,4 +1,3 @@ -import asyncio import datetime import enum import functools @@ -28,12 +27,15 @@ cast, ) +import anyio + if sys.version_info < (3, 10): from typing_extensions import Concatenate, ParamSpec else: from typing import Concatenate, ParamSpec from pylabrobot import audio +from pylabrobot.concurrency import AsyncExitStackWithShielding from pylabrobot.heating_shaking.hamilton_backend import HamiltonHeaterShakerInterface from pylabrobot.liquid_handling.backends.hamilton.base import ( HamiltonLiquidHandler, @@ -1688,13 +1690,15 @@ def _parse_firmware_version_datetime(self, fw_version: str) -> datetime.date: raise ValueError(f"Could not parse year from firmware version string: '{fw_version}'") return datetime.date(int(year_match.group(1)), 
1, 1) - async def setup( + async def _enter_lifespan( self, - skip_instrument_initialization=False, - skip_pip=False, - skip_autoload=False, - skip_iswap=False, - skip_core96_head=False, + stack: AsyncExitStackWithShielding, + *, + skip_instrument_initialization: bool = False, + skip_pip: bool = False, + skip_autoload: bool = False, + skip_iswap: bool = False, + skip_core96_head: bool = False, ): """Creates a USB connection and finds read/write interfaces. @@ -1703,8 +1707,7 @@ async def setup( skip_iswap: if True, skip initializing the iSWAP module, if applicable. skip_core96_head: if True, skip initializing the CoRe 96 head module, if applicable. """ - - await super().setup() + await super()._enter_lifespan(stack) self.id_ = 0 @@ -1794,7 +1797,11 @@ async def set_up_arm_modules(): await set_up_iswap() await set_up_core96_head() - await asyncio.gather(set_up_autoload(), set_up_arm_modules()) + async with anyio.create_task_group() as tg: + tg.start_soon(set_up_autoload) + tg.start_soon(set_up_arm_modules) + # task-group will block on exit until all tasks complete; + # unless some fail, then the remaining are cancelled. # After setup, STAR will have thrown out anything mounted on the pipetting channels, including # the core grippers. @@ -1802,9 +1809,9 @@ async def set_up_arm_modules(): self._setup_done = True - async def stop(self): - await super().stop() - self._setup_done = False + @stack.callback + def exit(): + self._setup_done = False @property def setup_done(self) -> bool: @@ -1891,11 +1898,16 @@ async def channels_request_y_minimum_spacing(self) -> List[float]: Returns: A list of minimum Y spacings in mm, one per channel. 
""" - return list( - await asyncio.gather( - *(self.channel_request_y_minimum_spacing(i) for i in range(self.num_channels)) - ) - ) + results: List[Optional[float]] = [None] * self.num_channels + + async def _worker(idx): + results[idx] = await self.channel_request_y_minimum_spacing(idx) + + async with anyio.create_task_group() as tg: + for idx in range(self.num_channels): + tg.start_soon(_worker, idx) + + return cast(List[float], results) def can_reach_position(self, channel_idx: int, position: Coordinate) -> bool: """Check if a position is reachable by a channel (center-based).""" @@ -1977,11 +1989,16 @@ async def channels_request_cycle_counts(self) -> List[ChannelCycleCounts]: and ``dispensing_cycles``. """ - return list( - await asyncio.gather( - *(self.channel_request_cycle_counts(channel_idx=idx) for idx in range(self.num_channels)) - ) - ) + results: List[Optional[Any]] = [None] * self.num_channels + + async def _worker(idx): + results[idx] = await self.channel_request_cycle_counts(channel_idx=idx) + + async with anyio.create_task_group() as tg: + for idx in range(self.num_channels): + tg.start_soon(_worker, idx) + + return cast(List["STARBackend.ChannelCycleCounts"], results) # # # ACTION Commands # # # @@ -2245,12 +2262,9 @@ async def _prepare_batched( # Z pre-positioning idle_channels = sorted(set(range(self.num_channels)) - set(use_channels)) if min_traverse_height_at_beginning_of_command is not None: - await asyncio.gather( - *[ - self.move_channel_stop_disk_z(channel_idx=ch_idx, z=self.MAXIMUM_CHANNEL_Z_POSITION) - for ch_idx in idle_channels - ] - ) + async with anyio.create_task_group() as tg: + for ch_idx in idle_channels: + tg.start_soon(self.move_channel_stop_disk_z, ch_idx, self.MAXIMUM_CHANNEL_Z_POSITION) await self.position_channels_in_z_direction( {ch: min_traverse_height_at_beginning_of_command for ch in use_channels} ) @@ -2309,23 +2323,27 @@ def _detect_func(mode: "STARBackend.LLDMode") -> Callable[..., Any]: measurements: Dict[int, 
List[Optional[float]]] = {orig_idx: [] for orig_idx in batch.indices} for _ in range(n_replicates): - results = await asyncio.gather( - *[ - _detect_func(lld_mode[orig_idx])( - channel_idx=channel, - lowest_immers_pos=lip, - start_pos_search=sps, - channel_speed=search_speed, - ) - for channel, lip, sps, orig_idx in zip( - batch.channels, batch_lowest_immers, batch_start_pos, batch.indices - ) - ], - return_exceptions=True, - ) + errors: List[Optional[Exception]] = [None] * len(batch.channels) + async with anyio.create_task_group() as tg: + for idx, (channel, lip, sps, orig_idx) in enumerate( + zip(batch.channels, batch_lowest_immers, batch_start_pos, batch.indices) + ): + + async def worker(i=idx, ch=channel, lip=lip, sps=sps, orig_idx=orig_idx): + try: + await _detect_func(lld_mode[orig_idx])( + channel_idx=ch, + lowest_immers_pos=lip, + start_pos_search=sps, + channel_speed=search_speed, + ) + except Exception as e: + errors[i] = e + + tg.start_soon(worker) current_absolute_liquid_heights = await self.request_pip_height_last_lld() - for local_idx, (ch_idx, result) in enumerate(zip(batch.channels, results)): + for local_idx, (ch_idx, result) in enumerate(zip(batch.channels, errors)): orig_idx = batch.indices[local_idx] if isinstance(result, STARFirmwareError): error_msg = str(result).lower() @@ -2719,19 +2737,17 @@ async def empty_tips( f"channel_idx must be between 0 and {self.num_channels - 1}, got {channels}" ) - await asyncio.gather( - *[ - self.empty_tip( - channel_idx=ch, - vol=vol, - flow_rate=flow_rate, - acceleration=acceleration, - current_limit=current_limit, - reset_dispensing_drive_after=reset_dispensing_drive_after, + async with anyio.create_task_group() as tg: + for ch in channels: + tg.start_soon( + self.empty_tip, + ch, + vol, + flow_rate, + acceleration, + current_limit, + reset_dispensing_drive_after, ) - for ch in channels - ] - ) # # # Channel Liquid Handling Commands # # # @@ -3631,24 +3647,26 @@ async def _core96_wait_for_idle(self, 
timeout: float = 600, poll_interval: float with H0 CommandSyntaxError trace 40 ("No parallel processes permitted"). When the head finishes, EV succeeds and harmlessly ensures the Z axis is at the safe position. """ - start = asyncio.get_event_loop().time() - while asyncio.get_event_loop().time() - start < timeout: - await asyncio.sleep(poll_interval) - try: - await self.send_command(module="C0", command="EV", read_timeout=10) - logger.info("CoRe 96 head finished (EV succeeded)") - return - except STARFirmwareError as e: - h0_error = e.errors.get("CoRe 96 Head") - if ( - h0_error is not None - and isinstance(h0_error, CommandSyntaxError) - and h0_error.trace_information == 40 - ): - logger.debug("CoRe 96 head still busy, waiting...") - continue - raise - raise TimeoutError("CoRe 96 head did not become idle within timeout") + try: + with anyio.fail_after(timeout): + while True: + await anyio.sleep(poll_interval) + try: + await self.send_command(module="C0", command="EV", read_timeout=10) + logger.info("CoRe 96 head finished (EV succeeded)") + return + except STARFirmwareError as e: + h0_error = e.errors.get("CoRe 96 Head") + if ( + h0_error is not None + and isinstance(h0_error, CommandSyntaxError) + and h0_error.trace_information == 40 + ): + logger.debug("CoRe 96 head still busy, waiting...") + continue + raise + except TimeoutError: + raise TimeoutError("CoRe 96 head did not become idle within timeout") from None @_requires_head96 async def aspirate96( @@ -5565,7 +5583,6 @@ async def request_machine_configuration(self) -> MachineConfiguration: Returns the basic machine configuration including configuration data 1 (kb) and number of PIP channels (kp). 
""" - resp = await self.send_command(module="C0", command="RM", fmt="kb**kp##") kb = resp["kb"] return MachineConfiguration( @@ -9510,7 +9527,7 @@ async def verify_and_wait_for_carriers( await self.set_loading_indicators(bit_pattern[::-1], blink_pattern[::-1]) # Wait before checking again - await asyncio.sleep(check_interval) + await anyio.sleep(check_interval) # Check for presence again detected_rails = set(await self.request_presence_of_carriers_on_deck()) diff --git a/pylabrobot/liquid_handling/backends/hamilton/STAR_chatterbox.py b/pylabrobot/liquid_handling/backends/hamilton/STAR_chatterbox.py index 28624096ba8..5125494938d 100644 --- a/pylabrobot/liquid_handling/backends/hamilton/STAR_chatterbox.py +++ b/pylabrobot/liquid_handling/backends/hamilton/STAR_chatterbox.py @@ -4,7 +4,7 @@ from contextlib import asynccontextmanager from typing import Dict, List, Literal, Optional, Union -from pylabrobot.liquid_handling.backends import LiquidHandlerBackend +from pylabrobot.concurrency import AsyncExitStackWithShielding from pylabrobot.liquid_handling.backends.hamilton.STAR_backend import ( DriveConfiguration, ExtendedConfiguration, @@ -100,13 +100,15 @@ def __init__( extended_configuration.min_raster_pitch_pip_channels ] * num_channels - async def setup( + async def _enter_lifespan( self, - skip_instrument_initialization=False, - skip_pip=False, - skip_autoload=False, - skip_iswap=False, - skip_core96_head=False, + stack: AsyncExitStackWithShielding, + *, + skip_instrument_initialization: bool = False, + skip_pip: bool = False, + skip_autoload: bool = False, + skip_iswap: bool = False, + skip_core96_head: bool = False, ): """Initialize the chatterbox backend and detect installed modules. @@ -117,7 +119,14 @@ async def setup( skip_iswap: If True, skip initializing the iSWAP module, if applicable. skip_core96_head: If True, skip initializing the CoRe 96 head module, if applicable. 
""" - await LiquidHandlerBackend.setup(self) + await super()._enter_lifespan( + stack, + skip_instrument_initialization=skip_instrument_initialization, + skip_pip=skip_pip, + skip_autoload=skip_autoload, + skip_iswap=skip_iswap, + skip_core96_head=skip_core96_head, + ) self.id_ = 0 @@ -137,10 +146,6 @@ async def setup( else: self._head96_information = None - async def stop(self): - await LiquidHandlerBackend.stop(self) - self._setup_done = False - # # # # # # # # Low-level command sending/receiving # # # # # # # # async def _write_and_read_command( diff --git a/pylabrobot/liquid_handling/backends/hamilton/STAR_tests.py b/pylabrobot/liquid_handling/backends/hamilton/STAR_tests.py index 2bb9a8b0152..e7514ca31b4 100644 --- a/pylabrobot/liquid_handling/backends/hamilton/STAR_tests.py +++ b/pylabrobot/liquid_handling/backends/hamilton/STAR_tests.py @@ -5,6 +5,7 @@ import unittest.mock from typing import Literal, cast +from pylabrobot.concurrency import AsyncExitStackWithShielding from pylabrobot.liquid_handling import LiquidHandler from pylabrobot.liquid_handling.standard import GripDirection, Pickup from pylabrobot.plate_reading import PlateReader @@ -30,9 +31,12 @@ from pylabrobot.resources.barcode import Barcode from pylabrobot.resources.greiner import Greiner_384_wellplate_28ul_Fb from pylabrobot.resources.hamilton import STARLetDeck, hamilton_96_tiprack_300uL_filter +from pylabrobot.testing.concurrency import AnyioTestBase, lifespan_kwargs +from pylabrobot.testing.mock_io import MockIO from .STAR_backend import ( CommandSyntaxError, + HamiltonLiquidHandler, HamiltonNoTipError, HardwareError, STARBackend, @@ -50,7 +54,6 @@ class TestSTARResponseParsing(unittest.TestCase): """Test parsing of response from Hamilton.""" def setUp(self): - super().setUp() self.star = STARBackend() def test_parse_response_params(self): @@ -147,14 +150,19 @@ def _any_write_and_read_command_call(cmd): ) -class TestSTARUSBComms(unittest.IsolatedAsyncioTestCase): +class 
TestSTARUSBComms(AnyioTestBase): """Test that USB data is parsed correctly.""" - async def asyncSetUp(self): + async def _enter_lifespan(self, stack): self.star = STARBackend(read_timeout=1, packet_read_timeout=1) self.star.set_deck(STARLetDeck()) - self.star.io = unittest.mock.AsyncMock() - await super().asyncSetUp() + self.star.io = MockIO() # type: ignore + # We need to temporarily replace _enter_lifespan with one that forwards to the parent class, + # so as not to do any hardware setup on enter, but still start the reader loop. + self.star._enter_lifespan = lambda stack, **kwargs: HamiltonLiquidHandler._enter_lifespan( + self.star, stack + ) + await stack.enter_async_context(self.star) async def test_send_command_correct_response(self): self.star.io.read.side_effect = [b"C0QMid0001"] @@ -180,11 +188,19 @@ def __init__(self): super().__init__() self.commands = [] - async def setup(self) -> None: # type: ignore + async def _enter_lifespan(self, stack: AsyncExitStackWithShielding, **kwargs) -> None: + # Bypass STARBackend._enter_lifespan to avoid sending commands to mock machine. 
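The test migration above collapses paired `asyncSetUp`/`asyncTearDown` methods into a single `_enter_lifespan(stack)` that enters the handler with `stack.enter_async_context(...)`. A toy sketch of why that composes (this uses a stand-in resource, not the real `AnyioTestBase`):

```python
import asyncio
import contextlib

events = []


class ToyResource:
    """Minimal async-context resource standing in for LiquidHandler."""

    async def __aenter__(self):
        events.append("setup")
        return self

    async def __aexit__(self, *exc):
        events.append("stop")
        return False


async def run_test():
    events.clear()
    async with contextlib.AsyncExitStack() as stack:
        # one line replaces the asyncSetUp/asyncTearDown pair; exits run
        # in reverse entry order, and run even if the test body raises
        await stack.enter_async_context(ToyResource())
        events.append("test body")
    return events


print(asyncio.run(run_test()))  # -> ['setup', 'test body', 'stop']
```

Because teardown is driven by the stack, a fixture that fails half-way through setup only unwinds what was actually entered.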
self._num_channels = 8 self._machine_conf = _DEFAULT_MACHINE_CONFIGURATION self._extended_conf = _DEFAULT_EXTENDED_CONFIGURATION self._core_parked = True + self._setup_done = True + + def cleanup(): + self.stop_finished = True + self._setup_done = False + + stack.callback(cleanup) async def send_command( # type: ignore self, @@ -202,23 +218,18 @@ async def send_command( # type: ignore ) self.commands.append(cmd) - async def stop(self): - self.stop_finished = True - -class TestSTARLiquidHandlerCommands(unittest.IsolatedAsyncioTestCase): +class TestSTARLiquidHandlerCommands(AnyioTestBase): """Test STAR backend for liquid handling.""" - async def asyncSetUp(self): + async def _enter_lifespan(self, stack, *, with_lh=True): self.STAR = STARBackend(read_timeout=1) self.STAR._write_and_read_command = unittest.mock.AsyncMock() - self.STAR.io = unittest.mock.AsyncMock() - self.STAR.io.setup = unittest.mock.AsyncMock() - self.STAR.io.write = unittest.mock.MagicMock() - self.STAR.io.read = unittest.mock.MagicMock() + self.STAR.io = MockIO() # type: ignore self.deck = STARLetDeck() - self.lh = LiquidHandler(self.STAR, deck=self.deck) + if with_lh: + self.lh = LiquidHandler(self.STAR, deck=self.deck) self.tip_car = TIP_CAR_480_A00(name="tip carrier") self.tip_car[1] = self.tip_rack = hamilton_96_tiprack_300uL_filter(name="tip_rack_01") @@ -267,12 +278,21 @@ def __init__(self, name: str): self.STAR._num_channels = 8 self.STAR._machine_conf = _DEFAULT_MACHINE_CONFIGURATION self.STAR._extended_conf = _DEFAULT_EXTENDED_CONFIGURATION - self.STAR.setup = unittest.mock.AsyncMock() + self.STAR._core_parked = True self.STAR._iswap_parked = True - await self.lh.setup() + + # Bypass hardware initialization in _enter_lifespan + self.STAR._enter_lifespan = lambda stack, **kwargs: HamiltonLiquidHandler._enter_lifespan( + self.STAR, stack + ) + + if with_lh: + await stack.enter_async_context(self.lh) set_tip_tracking(enabled=False) + self.STAR._write_and_read_command.reset_mock() + 
self.STAR.id_ = 0 async def test_core_read_barcode_success(self): """core_read_barcode_of_picked_up_resource should send ZB and return a Barcode.""" @@ -382,9 +402,6 @@ async def test_core_read_barcode_manual_input_empty_raises_value_error(self): labware_description="Cos_96_PCR_0001", ) - async def asyncTearDown(self): - await self.lh.stop() - async def test_indicator_light(self): await self.STAR.set_loading_indicators(bit_pattern=[True] * 54, blink_pattern=[False] * 54) self.STAR._write_and_read_command.assert_has_calls( @@ -395,7 +412,7 @@ async def test_indicator_light(self): ] ) - def test_ops_to_fw_positions(self): + async def test_ops_to_fw_positions(self): """Convert channel positions to firmware positions.""" tip_a1 = self.tip_rack.get_item("A1") tip_f1 = self.tip_rack.get_item("F1") @@ -475,7 +492,7 @@ async def test_tip_pickup_56(self): self.STAR.io.write.reset_mock() async def test_tip_drop_56(self): - await self.test_tip_pickup_56() # pick up tips first + await self.test_tip_pickup_56.original_func(self) # pick up tips first self.STAR._write_and_read_command.side_effect = [ "C0TRid0003kz000 000 000 000 000 000 000 000vz000 000 000 000 000 000 000 000" ] @@ -491,7 +508,7 @@ async def test_tip_drop_56(self): async def test_aspirate56(self): self.maxDiff = None - await self.test_tip_pickup_56() # pick up tips first + await self.test_tip_pickup_56.original_func(self) # pick up tips first assert self.plate.lid is not None self.plate.lid.unassign() for well in self.plate.get_items(["A1", "B1"]): @@ -1019,6 +1036,7 @@ async def test_discard_tips(self): ] ) + @lifespan_kwargs(with_lh=False) async def test_portrait_tip_rack_handling(self): deck = STARLetDeck() lh = LiquidHandler(self.STAR, deck=deck) @@ -1027,24 +1045,23 @@ async def test_portrait_tip_rack_handling(self): assert tr.rotation.z == 90 assert tr.location == Coordinate(82.6, 0, 0) deck.assign_child_resource(tip_car, rails=2) - await lh.setup() - - await lh.pick_up_tips(tr["A4:A1"]) - 
self.STAR._write_and_read_command.side_effect = [ - "C0TRid0002kz000 000 000 000 000 000 000 000vz000 000 000 000 000 000 000 000" - ] - await lh.drop_tips(tr["A4:A1"]) - - self.STAR._write_and_read_command.assert_has_calls( - [ - _any_write_and_read_command_call( - "C0TPid0002xp01360 01360 01360 01360 00000&yp1380 1290 1200 1110 0000&tm1 1 1 1 0&tt01tp2263tz2163th2450td0" - ), - _any_write_and_read_command_call( - "C0TRid0003xp01360 01360 01360 01360 00000&yp1380 1290 1200 1110 0000&tm1 1 1 1 0&tp2263tz2183th2450te2450ti1" - ), + async with lh: + await lh.pick_up_tips(tr["A4:A1"]) + self.STAR._write_and_read_command.side_effect = [ + "C0TRid0002kz000 000 000 000 000 000 000 000vz000 000 000 000 000 000 000 000" ] - ) + await lh.drop_tips(tr["A4:A1"]) + + self.STAR._write_and_read_command.assert_has_calls( + [ + _any_write_and_read_command_call( + "C0TPid0002xp01360 01360 01360 01360 00000&yp1380 1290 1200 1110 0000&tm1 1 1 1 0&tt01tp2263tz2163th2450td0" + ), + _any_write_and_read_command_call( + "C0TRid0003xp01360 01360 01360 01360 00000&yp1380 1290 1200 1110 0000&tm1 1 1 1 0&tp2263tz2183th2450te2450ti1" + ), + ] + ) def test_serialize(self): serialized = LiquidHandler(backend=STARBackend(), deck=STARLetDeck()).serialize() @@ -1081,8 +1098,8 @@ async def test_move_core(self): ) -class STARIswapMovementTests(unittest.IsolatedAsyncioTestCase): - async def asyncSetUp(self): +class STARIswapMovementTests(AnyioTestBase): + async def _enter_lifespan(self, stack): self.STAR = STARBackend() self.STAR._write_and_read_command = unittest.mock.AsyncMock() self.deck = STARLetDeck() @@ -1098,10 +1115,15 @@ async def asyncSetUp(self): self.STAR._num_channels = 8 self.STAR._machine_conf = _DEFAULT_MACHINE_CONFIGURATION self.STAR._extended_conf = _DEFAULT_EXTENDED_CONFIGURATION - self.STAR.setup = unittest.mock.AsyncMock() + + async def mock_enter_lifespan(stack, **kwargs): + pass + + self.STAR._enter_lifespan = mock_enter_lifespan + self.STAR._core_parked = True 
self.STAR._iswap_parked = True - await self.lh.setup() + await stack.enter_async_context(self.lh) async def test_simple_movement(self): await self.lh.move_plate(self.plate, self.plt_car[1]) @@ -1208,8 +1230,8 @@ async def test_move_lid_across_rotated_resources(self): ) -class STARFoilTests(unittest.IsolatedAsyncioTestCase): - async def asyncSetUp(self): +class STARFoilTests(AnyioTestBase): + async def _enter_lifespan(self, stack): self.star = STARBackend() self.star._write_and_read_command = unittest.mock.AsyncMock() self.deck = STARLetDeck() @@ -1227,11 +1249,15 @@ async def asyncSetUp(self): self.star._num_channels = 8 self.star._machine_conf = _DEFAULT_MACHINE_CONFIGURATION self.star._extended_conf = _DEFAULT_EXTENDED_CONFIGURATION - self.star.setup = unittest.mock.AsyncMock() + + async def mock_enter_lifespan(stack, **kwargs): + pass + + self.star._enter_lifespan = mock_enter_lifespan + self.star._core_parked = True self.star._iswap_parked = True - await self.lh.setup() - + await stack.enter_async_context(self.lh) await self.lh.pick_up_tips(self.tip_rack["A1:H1"]) async def test_pierce_foil_wide(self): @@ -1413,17 +1439,22 @@ async def test_pierce_foil_portrait_tight(self): ) -class TestSTARTipPickupDropAllSizes(unittest.IsolatedAsyncioTestCase): +class TestSTARTipPickupDropAllSizes(AnyioTestBase): """Test STAR tip pickup and drop Z position calculations for all tip sizes.""" - async def asyncSetUp(self): + async def _enter_lifespan(self, stack): self.backend = STARBackend() self.backend._write_and_read_command = unittest.mock.AsyncMock() - self.backend.io = unittest.mock.AsyncMock() + self.backend.io = MockIO() # type: ignore self.backend._num_channels = 8 self.backend._machine_conf = _DEFAULT_MACHINE_CONFIGURATION self.backend._extended_conf = _DEFAULT_EXTENDED_CONFIGURATION - self.backend.setup = unittest.mock.AsyncMock() + + async def mock_enter_lifespan(stack, **kwargs): + pass + + self.backend._enter_lifespan = mock_enter_lifespan + 
self.backend._core_parked = True self.backend._iswap_parked = True @@ -1433,7 +1464,7 @@ async def asyncSetUp(self): self.tip_car = TIP_CAR_480_A00(name="tip_carrier") self.deck.assign_child_resource(self.tip_car, rails=1) - await self.lh.setup() + await stack.enter_async_context(self.lh) set_tip_tracking(enabled=False) def _get_tp_tz_from_calls(self, cmd_prefix: str): @@ -1534,7 +1565,7 @@ async def test_1000uL_tips(self): tip_rack.unassign() -class TestChannelsMinimumYSpacing(unittest.IsolatedAsyncioTestCase): +class TestChannelsMinimumYSpacing(AnyioTestBase): """Test that different channel spacing configurations produce different behavior. Real firmware VY responses captured from hardware (GitHub issue #822): @@ -1542,6 +1573,9 @@ class TestChannelsMinimumYSpacing(unittest.IsolatedAsyncioTestCase): - 8-channel 9mm standard: PVYidyc000 194 0 (yc[1]=194 → 9.0mm) """ + async def _enter_lifespan(self, stack): + pass + # -- can_reach_position: reachability shrinks with wider spacing ---------------- async def test_can_reach_4ch_18mm_rejects_position_reachable_at_9mm(self): @@ -1636,16 +1670,13 @@ async def test_position_channels_make_space_spreads_wider_at_18mm(self): self.assertNotEqual(cmd_9mm, cmd_18mm) -class TestProbeLiquidHeights(unittest.IsolatedAsyncioTestCase): - """Tests for probe_liquid_heights: detection dispatch, replicates, error handling.""" +class STARTestBase(AnyioTestBase): + """Shared setup for probe/batch/helper tests.""" - async def asyncSetUp(self): + async def _enter_lifespan(self, stack): self.STAR = STARBackend(read_timeout=1) self.STAR._write_and_read_command = unittest.mock.AsyncMock() - self.STAR.io = unittest.mock.AsyncMock() - self.STAR.io.setup = unittest.mock.AsyncMock() - self.STAR.io.write = unittest.mock.MagicMock() - self.STAR.io.read = unittest.mock.MagicMock() + self.STAR.io = MockIO() # type: ignore self.deck = STARLetDeck() self.lh = LiquidHandler(self.STAR, deck=self.deck) @@ -1661,15 +1692,21 @@ async def asyncSetUp(self): 
     self.STAR._num_channels = 8
     self.STAR._machine_conf = _DEFAULT_MACHINE_CONFIGURATION
     self.STAR._extended_conf = _DEFAULT_EXTENDED_CONFIGURATION
-    self.STAR.setup = unittest.mock.AsyncMock()
+
+    async def mock_enter_lifespan(stack, **kwargs):
+      pass
+
+    self.STAR._enter_lifespan = mock_enter_lifespan
+
     self.STAR._core_parked = True
     self.STAR._iswap_parked = True
-    await self.lh.setup()
+    await stack.enter_async_context(self.lh)
     set_tip_tracking(enabled=False)
 
-  async def asyncTearDown(self):
-    await self.lh.stop()
+
+class TestProbeLiquidHeights(STARTestBase):
+  """Tests for probe_liquid_heights: detection dispatch, replicates, error handling."""
 
   def _put_tips_on_channels(self, channels):
     tip = self.tip_rack.get_tip("A1")
diff --git a/pylabrobot/liquid_handling/backends/hamilton/base.py b/pylabrobot/liquid_handling/backends/hamilton/base.py
index 73ff83be6f7..235bdb219a0 100644
--- a/pylabrobot/liquid_handling/backends/hamilton/base.py
+++ b/pylabrobot/liquid_handling/backends/hamilton/base.py
@@ -1,8 +1,5 @@
-import asyncio
 import datetime
 import logging
-import threading
-import time
 import warnings
 from abc import ABCMeta, abstractmethod
 from dataclasses import dataclass
@@ -13,8 +10,12 @@
   Sequence,
   Tuple,
   TypeVar,
+  Union,
 )
 
+import anyio
+
+from pylabrobot.concurrency import AsyncExitStackWithShielding, MachineConnectionClosedError
 from pylabrobot.io.usb import USB
 from pylabrobot.liquid_handling.backends.backend import (
   LiquidHandlerBackend,
@@ -37,10 +38,9 @@ class HamiltonTask:
   """A command that has been sent, awaiting a response."""
 
   id_: Optional[int]
-  loop: asyncio.AbstractEventLoop
-  fut: asyncio.Future
   cmd: str
-  timeout_time: float
+  done_event: anyio.Event
+  response: Optional[Union[str, Exception]]
 
 
 class HamiltonLiquidHandler(LiquidHandlerBackend, metaclass=ABCMeta):
@@ -83,9 +83,9 @@ def __init__(
 
     self.id_ = 0
 
-    self._reading_thread: Optional[threading.Thread] = None
-    self._reading_thread_stop = threading.Event()
-    self._waiting_tasks: List[HamiltonTask] = []
+    self._wakeup_reader_loop: Optional[anyio.Event] = None
+    self._waiting_tasks_with_id: dict[int, HamiltonTask] = {}
+    self._waiting_tasks_idless: dict[str, list[HamiltonTask]] = {}
     self._tth2tti: dict[int, int] = {}  # hash to tip type index
 
   def __setattr__(self, name: str, value: Any) -> None:
@@ -99,25 +99,28 @@ def __setattr__(self, name: str, value: Any) -> None:
     return super().__setattr__(name, value)
 
-  async def setup(self):
-    await super().setup()
-    await self.io.setup()
-    self._reading_thread_stop.clear()
-    self._reading_thread = threading.Thread(target=self._reading_thread_main, daemon=True)
-    self._reading_thread.start()
-
-  async def stop(self):
-    self._reading_thread_stop.set()
-    if self._reading_thread is not None:
-      self._reading_thread.join(timeout=10)
-      self._reading_thread = None
-    for task in self._waiting_tasks:
-      task.loop.call_soon_threadsafe(
-        task.fut.set_exception, RuntimeError("Stopping HamiltonLiquidHandler.")
-      )
-    self._waiting_tasks.clear()
-    self._tth2tti.clear()
-    await self.io.stop()
+  async def _enter_lifespan(self, stack: AsyncExitStackWithShielding):
+    await super()._enter_lifespan(stack)
+    await stack.enter_async_context(self.io)
+
+    # Put cleanup on the stack before the task group; this way,
+    # by the time we get here, the reader task has completed and done its cleanup.
+    @stack.callback
+    def cleanup():
+      self._wakeup_reader_loop = None
+      self._tth2tti.clear()
+      if self._waiting_tasks_with_id or self._waiting_tasks_idless:
+        warnings.warn(
+          "Internal problem: At this point, all waiting tasks should have been cleaned up!"
+        )
+        self._waiting_tasks_with_id.clear()
+        self._waiting_tasks_idless.clear()
+
+    self._wakeup_reader_loop = anyio.Event()
+    tg = await stack.enter_async_context(anyio.create_task_group())
+    # Put canceling the reader loop on top of the stack; it goes first
+    stack.callback(tg.cancel_scope.cancel)
+    tg.start_soon(self._continuously_read)
 
   def serialize(self) -> dict:
     usb_serialized = self.io.serialize()
@@ -255,7 +258,6 @@ async def send_command(
     Returns:
       A dictionary containing the parsed response, or None if no response was read within `timeout`.
     """
-
     cmd, id_ = self._assemble_command(
       module=module,
       command=command,
@@ -283,40 +285,49 @@ async def _write_and_read_command(
     wait: bool = True,
   ) -> Optional[str]:
     """Write a command to the Hamilton machine and read the response."""
-    await self.io.write(cmd.encode(), timeout=write_timeout)
-
     if not wait:
+      await self.io.write(cmd.encode(), timeout=write_timeout)
       return None
 
-    # Attempt to read packets until timeout, or when we identify the right id.
-    if read_timeout is None:
-      read_timeout = self.read_timeout
-
-    loop = asyncio.get_event_loop()
-    fut: asyncio.Future[str] = loop.create_future()
-    self._start_reading(id_, loop, fut, cmd, read_timeout)
-    result = await fut
-    return result
-
-  def _start_reading(
-    self,
-    id_: Optional[int],
-    loop: asyncio.AbstractEventLoop,
-    fut: asyncio.Future,
-    cmd: str,
-    timeout: int,
-  ) -> None:
-    """Submit a task to the reading thread."""
-
-    timeout_time = time.time() + timeout
-    self._waiting_tasks.append(
-      HamiltonTask(id_=id_, loop=loop, fut=fut, cmd=cmd, timeout_time=timeout_time)
-    )
-
-    if self._reading_thread is None or not self._reading_thread.is_alive():
-      self._reading_thread_stop.clear()
-      self._reading_thread = threading.Thread(target=self._reading_thread_main, daemon=True)
-      self._reading_thread.start()
+    done_evt = anyio.Event()
+    task = HamiltonTask(id_=id_, cmd=cmd, done_event=done_evt, response=None)
+    cmd_prefix = cmd[: self.module_id_length + 2]
+    try:
+      idle = not (self._waiting_tasks_with_id or self._waiting_tasks_idless)
+      if id_ is None:
+        # TODO: Do we want to allow multiple id-less tasks to be sent?
+        self._waiting_tasks_idless.setdefault(cmd_prefix, []).append(task)
+      else:
+        if self._waiting_tasks_with_id.setdefault(id_, task) is not task:
+          raise RuntimeError("Another task with this ID is already pending")
+      if idle:
+        assert self._wakeup_reader_loop is not None
+        self._wakeup_reader_loop.set()
+      await self.io.write(cmd.encode(), timeout=write_timeout)
+
+      # Attempt to read packets until timeout, or when we identify the right id.
+      if read_timeout is None:
+        read_timeout = self.read_timeout
+
+      with anyio.fail_after(read_timeout):
+        await done_evt.wait()
+    finally:
+      # The reader loop atomically removes tasks from the waiting lists and sets the response,
+      # so we have to remove ourselves from the waiting list exactly iff we don't have a response at this point.
+      if task.response is None:
+        if id_ is None:
+          self._waiting_tasks_idless[cmd_prefix].remove(task)
+        else:
+          del self._waiting_tasks_with_id[id_]
+
+    assert task.response is not None
+
+    if isinstance(task.response, Exception):
+      # An error occurred in the reader loop.
+      raise task.response
+
+    self.check_fw_string_error(task.response)
+    return task.response
 
   @abstractmethod
   def get_id_from_fw_response(self, resp: str) -> Optional[int]:
@@ -330,13 +341,8 @@ def check_fw_string_error(self, resp: str):
   def _parse_response(self, resp: str, fmt: Any) -> dict:
     """Parse a firmware response."""
 
-  def _reading_thread_main(self) -> None:
-    loop = asyncio.new_event_loop()
-    asyncio.set_event_loop(loop)
-    loop.run_until_complete(self._continuously_read())
-
   async def _continuously_read(self) -> None:
-    """Continuously read from the USB port until stop is requested.
+    """Continuously read from the USB port until cancelled.
 
     Tasks are stored in the `self._waiting_tasks` list, and contain a future that will be
     completed when the task is finished. Tasks are submitted to the list using the
@@ -346,52 +352,55 @@ async def _continuously_read(self) -> None:
     relevant to any of the tasks. If so, complete the future and remove the task from the list.
     If a task has timed out, complete the future with a `TimeoutError`.
""" + try: + while True: + if not (self._waiting_tasks_with_id or self._waiting_tasks_idless): + assert self._wakeup_reader_loop is not None + await self._wakeup_reader_loop.wait() + self._wakeup_reader_loop = anyio.Event() + continue - while not self._reading_thread_stop.is_set(): - for idx in range(len(self._waiting_tasks) - 1, -1, -1): # reverse order to allow deletion - task = self._waiting_tasks[idx] - if time.time() > task.timeout_time: - logger.warning("Timeout while waiting for response to command %s.", task.cmd) - task.loop.call_soon_threadsafe( - task.fut.set_exception, - TimeoutError(f"Timeout while waiting for response to command {task.cmd}."), - ) - del self._waiting_tasks[idx] - - if len(self._waiting_tasks) == 0: - await asyncio.sleep(0.01) - continue - - try: - resp = (await self.io.read()).decode("utf-8") - except TimeoutError: - continue - - if resp == "": - continue - - # Parse response. - try: - response_id = self.get_id_from_fw_response(resp) - except ValueError as e: - logger.warning("Could not parse response: %s (%s)", resp, e) - continue - - module_and_command = resp[: self.module_id_length + 2] - for idx in range(len(self._waiting_tasks)): - task = self._waiting_tasks[idx] - # if the command has no id, we have to check the command itself - if response_id == task.id_ or ( - task.id_ is None and task.cmd.startswith(module_and_command) - ): - try: - self.check_fw_string_error(resp) - except Exception as e: - task.loop.call_soon_threadsafe(task.fut.set_exception, e) - else: - task.loop.call_soon_threadsafe(task.fut.set_result, resp) - del self._waiting_tasks[idx] - break + try: + resp = (await self.io.read()).decode("utf-8") + except TimeoutError: + continue + + if resp == "": + continue + + # Parse response. 
+        try:
+          response_id = self.get_id_from_fw_response(resp)
+        except ValueError as e:
+          logger.warning("Could not parse response: %s (%s)", resp, e)
+          continue
+
+        cmd_prefix = resp[: self.module_id_length + 2]
+        task = None
+        if response_id is not None:
+          task = self._waiting_tasks_with_id.pop(response_id, None)
+        if task is None:
+          tasks = self._waiting_tasks_idless.get(cmd_prefix)
+          if tasks:
+            task = tasks.pop(0)
+            if not tasks:
+              del self._waiting_tasks_idless[cmd_prefix]
+        if task is not None:
+          task.response = resp
+          task.done_event.set()
+        else:
+          logger.warning("Received response for unknown command: %s", resp)
+    finally:
+      # Abort all remaining tasks
+      for task in self._waiting_tasks_with_id.values():
+        task.response = MachineConnectionClosedError()
+        task.done_event.set()
+      for tasks in self._waiting_tasks_idless.values():
+        for task in tasks:
+          task.response = MachineConnectionClosedError()
+          task.done_event.set()
+      self._waiting_tasks_with_id.clear()
+      self._waiting_tasks_idless.clear()
 
   def _ops_to_fw_positions(
     self, ops: Sequence[PipettingOp], use_channels: List[int]
diff --git a/pylabrobot/liquid_handling/backends/hamilton/nimbus_backend.py b/pylabrobot/liquid_handling/backends/hamilton/nimbus_backend.py
index 7018e22590b..682e48488c5 100644
--- a/pylabrobot/liquid_handling/backends/hamilton/nimbus_backend.py
+++ b/pylabrobot/liquid_handling/backends/hamilton/nimbus_backend.py
@@ -10,6 +10,7 @@
 import logging
 from typing import Dict, List, Optional, Sequence, Tuple, TypeVar, Union
 
+from pylabrobot.concurrency import AsyncExitStackWithShielding
 from pylabrobot.liquid_handling.backends.hamilton.common import fill_in_defaults
 from pylabrobot.liquid_handling.backends.hamilton.tcp.commands import HamiltonCommand
 from pylabrobot.liquid_handling.backends.hamilton.tcp.introspection import (
@@ -954,7 +955,13 @@ def __init__(
 
     self._channel_traversal_height: float = 146.0  # Default traversal height in mm
 
-  async def setup(self, unlock_door: bool = False, force_initialize: bool = False):
+  async def _enter_lifespan(
+    self,
+    stack: AsyncExitStackWithShielding,
+    *,
+    unlock_door: bool = False,
+    force_initialize: bool = False,
+  ):
     """Set up the Nimbus backend.
 
     This method:
@@ -972,7 +979,7 @@ async def setup(self, unlock_door: bool = False, force_initialize: bool = False)
       force_initialize: If True, force initialization even if already initialized
     """
     # Call parent setup (TCP connection, Protocol 7 init, Protocol 3 registration)
-    await super().setup()
+    await super()._enter_lifespan(stack)
 
     # Discover instrument objects
     await self._discover_instrument_objects()
@@ -1244,10 +1251,6 @@ async def unlock_door(self) -> None:
       logger.error(f"Failed to unlock door: {e}")
       raise
 
-  async def stop(self):
-    """Stop the backend and close connection."""
-    await HamiltonTCPBackend.stop(self)
-
   async def request_tip_presence(self) -> List[Optional[bool]]:
     """Request tip presence on each channel.
diff --git a/pylabrobot/liquid_handling/backends/hamilton/nimbus_backend_tests.py b/pylabrobot/liquid_handling/backends/hamilton/nimbus_backend_tests.py
index 75c71692f91..ec0fb246323 100644
--- a/pylabrobot/liquid_handling/backends/hamilton/nimbus_backend_tests.py
+++ b/pylabrobot/liquid_handling/backends/hamilton/nimbus_backend_tests.py
@@ -45,6 +45,7 @@
 from pylabrobot.resources.hamilton import HamiltonTip, TipPickupMethod, TipSize
 from pylabrobot.resources.hamilton.nimbus_decks import NimbusDeck
 from pylabrobot.resources.hamilton.tip_racks import hamilton_96_tiprack_300uL
+from pylabrobot.testing.concurrency import AnyioTestBase
 
 
 class TestNimbusTipType(unittest.TestCase):
@@ -530,7 +531,7 @@ def test_dispense_parameters(self):
     self.assertEqual(results[13][1], [1000, 0, 0, 0, 0, 0, 0, 0])
 
 
-class TestNimbusBackendUnit(unittest.IsolatedAsyncioTestCase):
+class TestNimbusBackendUnit(AnyioTestBase):
   """Unit tests for NimbusBackend class (no actual connection)."""
 
   async def test_backend_init(self):
@@ -620,10 +621,10 @@ def _setup_backend_with_deck(deck: NimbusDeck) -> NimbusBackend:
   return backend
 
 
-class TestNimbusBackendCommands(unittest.IsolatedAsyncioTestCase):
+class TestNimbusBackendCommands(AnyioTestBase):
   """Tests for NimbusBackend command methods."""
 
-  async def asyncSetUp(self):
+  async def _enter_lifespan(self, stack):
     self.backend = _setup_backend()
     self.mock_send = unittest.mock.AsyncMock(side_effect=_mock_send_command_response)
     self.backend.send_command = self.mock_send  # type: ignore[method-assign]
@@ -674,7 +675,7 @@ async def test_park_without_address_raises(self):
       await self.backend.park()
 
 
-class TestNimbusBackendSerialization(unittest.IsolatedAsyncioTestCase):
+class TestNimbusBackendSerialization(AnyioTestBase):
   """Tests for NimbusBackend serialization."""
 
   async def test_serialize(self):
@@ -686,10 +687,10 @@ async def test_serialize(self):
     self.assertIn("instrument_addresses", serialized)
 
 
-class TestNimbusLiquidHandling(unittest.IsolatedAsyncioTestCase):
+class TestNimbusLiquidHandling(AnyioTestBase):
   """Tests for NimbusBackend liquid handling command generation."""
 
-  async def asyncSetUp(self):
+  async def _enter_lifespan(self, stack):
     self.deck = NimbusDeck()
     self.backend = _setup_backend_with_deck(self.deck)
     self.mock_send = unittest.mock.AsyncMock(side_effect=_mock_send_command_response)
@@ -1058,14 +1059,14 @@ async def test_offset_applied_to_coordinates(self):
     self.assertEqual(x_with_offset - x_no_offset, 1000)
 
 
-class TestNimbusTipPickupDropAllSizes(unittest.IsolatedAsyncioTestCase):
+class TestNimbusTipPickupDropAllSizes(AnyioTestBase):
   """Tests for Nimbus tip pickup/drop Z positions across all tip sizes.
 
   These tests verify that the begin/end tip pickup and drop process values
   match the machine-validated values.
""" - async def asyncSetUp(self): + async def _enter_lifespan(self, stack): self.deck = NimbusDeck() self.backend = _setup_backend_with_deck(self.deck) self.mock_send = unittest.mock.AsyncMock(side_effect=_mock_send_command_response) @@ -1183,7 +1184,3 @@ async def test_1000uL_tips(self): self.assertEqual(drop_cmd.end_tip_deposit_process[0], -8350) tip_rack.unassign() - - -if __name__ == "__main__": - unittest.main() diff --git a/pylabrobot/liquid_handling/backends/hamilton/tcp/tcp_tests.py b/pylabrobot/liquid_handling/backends/hamilton/tcp/tcp_tests.py index 80c249a424a..cfd5bb45678 100644 --- a/pylabrobot/liquid_handling/backends/hamilton/tcp/tcp_tests.py +++ b/pylabrobot/liquid_handling/backends/hamilton/tcp/tcp_tests.py @@ -981,7 +981,3 @@ def test_hamilton_data_type_values(self): self.assertEqual(HamiltonDataType.STRING, 15) self.assertEqual(HamiltonDataType.BOOL, 23) self.assertEqual(HamiltonDataType.I32_ARRAY, 27) - - -if __name__ == "__main__": - unittest.main() diff --git a/pylabrobot/liquid_handling/backends/hamilton/tcp_backend.py b/pylabrobot/liquid_handling/backends/hamilton/tcp_backend.py index 9c6a9acbb13..8ab401ba305 100644 --- a/pylabrobot/liquid_handling/backends/hamilton/tcp_backend.py +++ b/pylabrobot/liquid_handling/backends/hamilton/tcp_backend.py @@ -6,11 +6,11 @@ from __future__ import annotations -import asyncio import logging from dataclasses import dataclass from typing import Dict, Optional, Union +from pylabrobot.concurrency import AsyncExitStackWithShielding from pylabrobot.io.binary import Reader from pylabrobot.io.socket import Socket from pylabrobot.liquid_handling.backends.backend import LiquidHandlerBackend @@ -144,19 +144,11 @@ async def _reconnect(self): f"{self.io._unique_id} Reconnection attempt {attempt + 1}/{self.max_reconnect_attempts}" ) - # Clean up existing connection - try: - await self.stop() - except Exception: - pass + # Attempt to reconnect with wait time between disconnect and connect + wait_time = 1.0 * (2 ** 
(attempt - 1)) if attempt > 0 else 0.0 + await self.io.reconnect(wait_time=wait_time) - # Wait before reconnecting (exponential backoff) - if attempt > 0: - wait_time = 1.0 * (2 ** (attempt - 1)) # 1s, 2s, 4s, etc. - await asyncio.sleep(wait_time) - - # Attempt to reconnect - await self.setup() + await self._initialize_hamilton() self._reconnect_attempts = 0 logger.info(f"{self.io._unique_id} Reconnection successful") return @@ -289,7 +281,7 @@ async def _read_one_message(self) -> Union[RegistrationResponse, CommandResponse logger.warning(f"Unknown IP protocol: {ip_protocol}, attempting CommandResponse parse") return CommandResponse.from_bytes(complete_data) - async def setup(self): + async def _initialize_hamilton(self): """Initialize Hamilton connection and discover objects. Hamilton uses strict request-response protocol: @@ -298,14 +290,6 @@ async def setup(self): 3. Protocol 3 registration 4. Discover objects via Protocol 3 introspection """ - - # Step 1: Establish TCP connection - await self.io.setup() - - # Set connection state after successful connection - self._connected = True - self._reconnect_attempts = 0 - # Step 2: Initialize connection (Protocol 7) await self._initialize_connection() @@ -315,6 +299,22 @@ async def setup(self): # Step 4: Discover root objects await self._discover_root() + async def _enter_lifespan(self, stack: AsyncExitStackWithShielding): + await super()._enter_lifespan(stack) + await stack.enter_async_context(self.io) + + def cleanup(): + self._connected = False + logger.info("Hamilton backend stopped") + + stack.callback(cleanup) + + # Set connection state after successful connection + self._connected = True + self._reconnect_attempts = 0 + + await self._initialize_hamilton() + logger.info(f"Hamilton backend setup complete. 
Client ID: {self._client_id}") async def _initialize_connection(self): @@ -571,16 +571,6 @@ async def send_command(self, command: HamiltonCommand, timeout: float = 10.0) -> return command.interpret_response(response_message) - async def stop(self): - """Stop the backend and close connection.""" - try: - await self.io.stop() - except Exception as e: - logger.warning(f"Error during stop: {e}") - finally: - self._connected = False - logger.info("Hamilton backend stopped") - def serialize(self) -> dict: """Serialize backend configuration.""" return { diff --git a/pylabrobot/liquid_handling/backends/hamilton/vantage_backend.py b/pylabrobot/liquid_handling/backends/hamilton/vantage_backend.py index a68b8375790..005e570bcc8 100644 --- a/pylabrobot/liquid_handling/backends/hamilton/vantage_backend.py +++ b/pylabrobot/liquid_handling/backends/hamilton/vantage_backend.py @@ -1,10 +1,12 @@ -import asyncio import random import re import sys import warnings from typing import Dict, List, Optional, Sequence, Union, cast +import anyio + +from pylabrobot.concurrency import AsyncExitStackWithShielding from pylabrobot.liquid_handling.backends.hamilton.base import ( HamiltonLiquidHandler, ) @@ -400,15 +402,17 @@ def _parse_response(self, resp: str, fmt: Dict[str, str]) -> dict: """Parse a firmware response.""" return parse_vantage_fw_string(resp, fmt) - async def setup( + async def _enter_lifespan( self, + stack: AsyncExitStackWithShielding, + *, skip_loading_cover: bool = False, skip_core96: bool = False, skip_ipg: bool = False, ): """Creates a USB connection and finds read/write interfaces.""" - await super().setup() + await super()._enter_lifespan(stack) tip_presences = await self.query_tip_presence() self._num_channels = len(tip_presences) @@ -5280,7 +5284,7 @@ async def disco_mode(self): random.randint(30, 100), ) await self.set_led_color("on", intensity=100, white=0, red=r, green=g, blue=b, uv=0) - await asyncio.sleep(0.1) + await anyio.sleep(0.1) async def 
russian_roulette(self): """Dangerous easter egg.""" @@ -5307,7 +5311,7 @@ async def russian_roulette(self): await self.set_led_color("on", intensity=100, white=100, red=0, green=100, blue=0, uv=0) print("You won.") - await asyncio.sleep(5) + await anyio.sleep(5) await self.set_led_color( "on", intensity=100, diff --git a/pylabrobot/liquid_handling/backends/hamilton/vantage_tests.py b/pylabrobot/liquid_handling/backends/hamilton/vantage_tests.py index b7c95621fc0..725fd2dd92c 100644 --- a/pylabrobot/liquid_handling/backends/hamilton/vantage_tests.py +++ b/pylabrobot/liquid_handling/backends/hamilton/vantage_tests.py @@ -1,6 +1,7 @@ import unittest from typing import Any, List, Optional +from pylabrobot.concurrency import AsyncExitStackWithShielding from pylabrobot.liquid_handling import LiquidHandler from pylabrobot.liquid_handling.standard import Pickup from pylabrobot.resources import ( @@ -15,6 +16,7 @@ set_tip_tracking, ) from pylabrobot.resources.hamilton import VantageDeck +from pylabrobot.testing.concurrency import AnyioTestBase from .vantage_backend import ( VantageBackend, @@ -213,12 +215,19 @@ def __init__(self): super().__init__() self.commands = [] - async def setup(self) -> None: # type: ignore + async def _enter_lifespan(self, stack: AsyncExitStackWithShielding, **kwargs) -> None: self.setup_finished = True self._num_channels = 8 self.iswap_installed = True self._num_arms = 1 self._head96_installed = True + self._setup_done = True + + def cleanup(): + self.stop_finished = True + self._setup_done = False + + stack.callback(cleanup) async def send_command( self, @@ -237,14 +246,11 @@ async def send_command( ) self.commands.append(cmd) - async def stop(self): - self.stop_finished = True - -class TestVantageLiquidHandlerCommands(unittest.IsolatedAsyncioTestCase): +class TestVantageLiquidHandlerCommands(AnyioTestBase): """Test Vantage backend for liquid handling.""" - async def asyncSetUp(self): + async def _enter_lifespan(self, stack): self.mockVantage = 
VantageCommandCatcher() self.deck = VantageDeck(size=1.3) self.lh = LiquidHandler(self.mockVantage, deck=self.deck) @@ -261,13 +267,10 @@ async def asyncSetUp(self): self.maxDiff = None - await self.lh.setup() + await stack.enter_async_context(self.lh) set_tip_tracking(enabled=False) - async def asyncTearDown(self): - await self.lh.stop() - def _assert_command_in_command_buffer(self, cmd: str, should_be: bool, fmt: dict): """Assert that the given command was sent to the backend. The ordering of the parameters is not taken into account, but the values and formatting should match. The id parameter of the command @@ -361,7 +364,7 @@ async def test_tip_pickup_01(self): ) async def test_tip_drop_01(self): - await self.test_tip_pickup_01() # pick up tips first + await self.test_tip_pickup_01.original_func(self) # type: ignore # pick up tips first await self.lh.drop_tips(self.tip_rack["A1", "B1"]) self._assert_command_sent_once( "A1PMTRid013xp04329 04329 0&yp1458 1368 0&tm1 1 0&tp1414 1414&tz1314 1314&th2450 2450&" @@ -377,7 +380,7 @@ async def test_small_tip_pickup(self): ) async def test_small_tip_drop(self): - await self.test_small_tip_pickup() # pick up tips first + await self.test_small_tip_pickup.original_func(self) # type: ignore # pick up tips first await self.lh.drop_tips(self.small_tip_rack["A1"]) self._assert_command_sent_once( "A1PMTRid0012xp4329 0&yp2418 0&tp2024&tz1924&th2450&te2450&tm1 0&ts0td0&", @@ -575,10 +578,10 @@ async def test_move_plate(self): ) -class TestVantageTipPickupDropAllSizes(unittest.IsolatedAsyncioTestCase): +class TestVantageTipPickupDropAllSizes(AnyioTestBase): """Test Vantage tip pickup and drop Z position calculations for all tip sizes.""" - async def asyncSetUp(self): + async def _enter_lifespan(self, stack): self.backend = VantageCommandCatcher() self.deck = VantageDeck(size=1.3) self.lh = LiquidHandler(self.backend, deck=self.deck) @@ -586,12 +589,9 @@ async def asyncSetUp(self): self.tip_car = TIP_CAR_480_A00(name="tip_carrier") 
     self.deck.assign_child_resource(self.tip_car, rails=18)
 
-    await self.lh.setup()
+    await stack.enter_async_context(self.lh)
     set_tip_tracking(enabled=False)
 
-  async def asyncTearDown(self):
-    await self.lh.stop()
-
   def _get_tp_tz_from_commands(self, cmd_prefix: str, fmt: dict):
     """Extract tp and tz values from commands matching the prefix."""
     for cmd in self.backend.commands:
diff --git a/pylabrobot/liquid_handling/backends/opentrons_backend.py b/pylabrobot/liquid_handling/backends/opentrons_backend.py
index f5e30322a9e..d1be010e64d 100644
--- a/pylabrobot/liquid_handling/backends/opentrons_backend.py
+++ b/pylabrobot/liquid_handling/backends/opentrons_backend.py
@@ -2,6 +2,7 @@
 from typing import Dict, List, Optional, Tuple, Union, cast
 
 from pylabrobot import utils
+from pylabrobot.concurrency import AsyncExitStackWithShielding
 from pylabrobot.liquid_handling.backends.backend import (
   LiquidHandlerBackend,
 )
@@ -95,7 +96,7 @@ def serialize(self) -> dict:
       "port": self.port,
     }
 
-  async def setup(self, skip_home: bool = False):
+  async def _enter_lifespan(self, stack: AsyncExitStackWithShielding, *, skip_home: bool = False):
     # create run
     run_id = ot_api.runs.create()
     ot_api.set_run(run_id)
@@ -112,30 +113,30 @@ async def setup(self, skip_home: bool = False):
     if not skip_home:
       await self.home()
 
-  @property
-  def num_channels(self) -> int:
-    return len([p for p in [self.left_pipette, self.right_pipette] if p is not None])
+    @stack.callback
+    def cleanup():
+      self._plr_name_to_load_name = {}
+      self._tip_racks = {}
+      self.left_pipette = None
+      self.right_pipette = None
 
-  async def stop(self):
-    """Cancel any active OT run, then clear labware definitions."""
-    self._plr_name_to_load_name = {}
-    self._tip_racks = {}
-    self.left_pipette = None
-    self.right_pipette = None
-
-    # cancel the HTTP-API run if it exists (helpful to make device available again in official Opentrons app)
-    run_id = getattr(ot_api, "run_id", None)
-    if run_id:
-      try:
-        _req.post(f"/runs/{run_id}/cancel")
-      except Exception:
+      # cancel the HTTP-API run if it exists (helpful to make device available again in official Opentrons app)
+      run_id = getattr(ot_api, "run_id", None)
+      if run_id:
         try:
-          _req.post(f"/runs/{run_id}/actions/cancel")
+          _req.post(f"/runs/{run_id}/cancel")
         except Exception:
           try:
-            _req.delete(f"/runs/{run_id}")
+            _req.post(f"/runs/{run_id}/actions/cancel")
           except Exception:
-            pass
+            try:
+              _req.delete(f"/runs/{run_id}")
+            except Exception:
+              pass
+
+  @property
+  def num_channels(self) -> int:
+    return len([p for p in [self.left_pipette, self.right_pipette] if p is not None])
 
   def get_ot_name(self, plr_resource_name: str) -> str:
     """Opentrons only allows names in ^[a-z0-9._]+$, but in PLR we are flexible.
diff --git a/pylabrobot/liquid_handling/backends/opentrons_backend_tests.py b/pylabrobot/liquid_handling/backends/opentrons_backend_tests.py
index 05ea8e2845f..c52c5456ed8 100644
--- a/pylabrobot/liquid_handling/backends/opentrons_backend_tests.py
+++ b/pylabrobot/liquid_handling/backends/opentrons_backend_tests.py
@@ -3,6 +3,8 @@
 
 import pytest
 
+from pylabrobot.testing.concurrency import AnyioTestBase
+
 pytest.importorskip("ot_api")
 
 from pylabrobot.liquid_handling import LiquidHandler
@@ -35,7 +37,7 @@ def _mock_health_get():
   }
 
 
-class OpentronsBackendSetupTests(unittest.IsolatedAsyncioTestCase):
+class OpentronsBackendSetupTests(AnyioTestBase):
   """Tests for setup and stop"""
 
   @patch("ot_api.runs.create")
@@ -44,8 +46,10 @@ class OpentronsBackendSetupTests(unittest.IsolatedAsyncioTestCase):
   @patch("ot_api.labware.add")
   @patch("ot_api.labware.define")
   @patch("ot_api.health.get")
-  async def test_setup(
+  async def _enter_lifespan(
     self,
+    stack,
+    *,
     mock_health_get,
     mock_define,
     mock_add,
@@ -64,7 +68,20 @@ async def test_setup(
 
     self.backend = OpentronsOT2Backend(host="localhost", port=1338)
     self.lh = LiquidHandler(backend=self.backend, deck=OTDeck())
-    await self.lh.setup()
+
+    self.mock_create = mock_create
+    self.mock_home = mock_home
+    self.mock_add_mounted_pipettes = mock_add_mounted_pipettes
+    self.mock_define = mock_define
+    self.mock_add = mock_add
+    self.mock_health_get = mock_health_get
+
+    await stack.enter_async_context(self.lh)
+
+  async def test_setup(self):
+    self.mock_create.assert_called_once()
+    self.mock_home.assert_called_once()
+    self.mock_add_mounted_pipettes.assert_called_once()
 
   def test_serialize(self):
     serialized = OpentronsOT2Backend(host="localhost", port=1337).serialize()
@@ -78,7 +95,7 @@ def test_serialize(self):
     )
 
 
-class OpentronsBackendCommandTests(unittest.IsolatedAsyncioTestCase):
+class OpentronsBackendCommandTests(AnyioTestBase):
   """Tests Opentrons commands"""
 
   @patch("ot_api.runs.create")
@@ -87,8 +104,9 @@ class OpentronsBackendCommandTests(unittest.IsolatedAsyncioTestCase):
   @patch("ot_api.labware.add")
   @patch("ot_api.labware.define")
   @patch("ot_api.health.get")
-  async def asyncSetUp(
+  async def _enter_lifespan(
     self,
+    stack,
     mock_health_get,
     mock_define,
     mock_add,
@@ -108,7 +126,7 @@ async def asyncSetUp(
     self.backend = OpentronsOT2Backend(host="localhost", port=1338)
     self.deck = OTDeck()
     self.lh = LiquidHandler(backend=self.backend, deck=self.deck)
-    await self.lh.setup()
+    await stack.enter_async_context(self.lh)
 
     self.tip_rack = opentrons_96_filtertiprack_20ul(name="tip_rack")
     self.deck.assign_child_at_slot(self.tip_rack, slot=1)
@@ -147,7 +165,7 @@ def assert_parameters(labware_id, well_name, pipette_id, offset_x, offset_y, off
 
     mock_drop_tip.side_effect = assert_parameters
 
-    await self.test_tip_pick_up()
+    await self.test_tip_pick_up.original_func(self)  # type: ignore[attr-defined]
     await self.lh.drop_tips(self.tip_rack["A1"])
 
   @patch("ot_api.lh.aspirate_in_place")
@@ -166,7 +184,7 @@ def assert_parameters(
 
     mock_aspirate.side_effect = assert_parameters
 
-    await self.test_tip_pick_up()
+    await self.test_tip_pick_up.original_func(self)  # type: ignore[attr-defined]
     self.plate.get_well("A1").tracker.set_volume(10)
     await self.lh.aspirate(self.plate["A1"], vols=[10])
diff --git a/pylabrobot/liquid_handling/backends/opentrons_simulator.py b/pylabrobot/liquid_handling/backends/opentrons_simulator.py
index 914ba848ea3..90993b58c01 100644
--- a/pylabrobot/liquid_handling/backends/opentrons_simulator.py
+++ b/pylabrobot/liquid_handling/backends/opentrons_simulator.py
@@ -7,6 +7,7 @@
 import logging
 from typing import Dict, List, Optional, Tuple
 
+from pylabrobot.concurrency import AsyncExitStackWithShielding
 from pylabrobot.liquid_handling.backends.backend import LiquidHandlerBackend
 from pylabrobot.liquid_handling.backends.opentrons_backend import OpentronsOT2Backend
 from pylabrobot.liquid_handling.standard import (
@@ -88,27 +89,28 @@ def serialize(self) -> dict:
       "right_pipette_name": self._right_pipette_name,
     }
 
-  async def setup(self, skip_home: bool = False):
-    await LiquidHandlerBackend.setup(self)
+  async def _enter_lifespan(self, stack: AsyncExitStackWithShielding, *, skip_home: bool = False):
+    # Skip OpentronsOT2Backend._enter_lifespan (requires ot_api); call grandparent directly.
+    await LiquidHandlerBackend._enter_lifespan(self, stack)
     self._setup_pipettes()
     logger.info(
       "OpentronsOT2Simulator setup: left=%s, right=%s",
       self._left_pipette_name,
       self._right_pipette_name,
     )
-    if not skip_home:
-      await self.home()
+
+    def cleanup():
+      self.left_pipette = None
+      self.right_pipette = None
+      self.left_pipette_has_tip = False
+      self.right_pipette_has_tip = False
+      logger.info("OpentronsOT2Simulator stopped.")
+
+    stack.callback(cleanup)

   async def home(self):
     logger.info("Homing (simulated).")

-  async def stop(self):
-    self.left_pipette = None
-    self.right_pipette = None
-    self.left_pipette_has_tip = False
-    self.right_pipette_has_tip = False
-    logger.info("OpentronsOT2Simulator stopped.")
-
   def _current_channel_position(self, channel: int) -> Tuple[str, Coordinate]:
     pipette_id = self._pipette_id_for_channel(channel)
     return pipette_id, self._positions.get(pipette_id, Coordinate.zero())
diff --git a/pylabrobot/liquid_handling/backends/serializing_backend.py b/pylabrobot/liquid_handling/backends/serializing_backend.py
index a227f3b9d43..6b7adc2a4e7 100644
--- a/pylabrobot/liquid_handling/backends/serializing_backend.py
+++ b/pylabrobot/liquid_handling/backends/serializing_backend.py
@@ -1,6 +1,7 @@
 from abc import ABCMeta, abstractmethod
 from typing import Any, Dict, List, Optional, Union, cast

+from pylabrobot.concurrency import AsyncExitStackWithShielding
 from pylabrobot.liquid_handling.backends.backend import (
   LiquidHandlerBackend,
 )
@@ -43,12 +44,14 @@ async def send_command(
   ) -> Optional[dict]:
     raise NotImplementedError

-  async def setup(self):
-    await super().setup()
+  async def _enter_lifespan(self, stack: AsyncExitStackWithShielding):
+    await super()._enter_lifespan(stack)
     await self.send_command(command="setup")

-  async def stop(self):
-    await self.send_command(command="stop")
+    async def cleanup():
+      await self.send_command(command="stop")
+
+    stack.push_shielded_async_callback(cleanup)

   def serialize(self) -> dict:
     return {**super().serialize(), "num_channels": self.num_channels}
diff --git a/pylabrobot/liquid_handling/backends/serializing_backend_tests.py b/pylabrobot/liquid_handling/backends/serializing_backend_tests.py
index b15dd9fd589..570322332f0 100644
--- a/pylabrobot/liquid_handling/backends/serializing_backend_tests.py
+++ b/pylabrobot/liquid_handling/backends/serializing_backend_tests.py
@@ -1,4 +1,3 @@
-import unittest
 from unittest.mock import AsyncMock

 from pylabrobot.liquid_handling import LiquidHandler
@@ -16,21 +15,23 @@
   no_volume_tracking,
 )
 from pylabrobot.serializer import serialize
+from pylabrobot.testing.concurrency import AnyioTestBase


 class _TestSerializingBackend(SerializingBackend):
   send_command = AsyncMock()


-class SerializingBackendTests(unittest.IsolatedAsyncioTestCase):
+class TestSerializingBackend(AnyioTestBase):
   """Tests for the serializing backend"""

-  async def asyncSetUp(self) -> None:
+  async def _enter_lifespan(self, stack) -> None:
+    await super()._enter_lifespan(stack)
     self.backend = _TestSerializingBackend(num_channels=8)
     self.backend.send_command.reset_mock()
     self.deck = STARLetDeck()
     self.lh = LiquidHandler(backend=self.backend, deck=self.deck)
-    await self.lh.setup()
+    await stack.enter_async_context(self.lh)

     self.tip_car = TIP_CAR_480_A00(name="tip carrier")
     self.tip_car[0] = self.tip_rack = hamilton_96_tiprack_300uL_filter(name="tip_rack_01")
diff --git a/pylabrobot/liquid_handling/backends/tecan/EVO_backend.py b/pylabrobot/liquid_handling/backends/tecan/EVO_backend.py
index c4d9d2f2a9c..de230e6cbeb 100644
--- a/pylabrobot/liquid_handling/backends/tecan/EVO_backend.py
+++ b/pylabrobot/liquid_handling/backends/tecan/EVO_backend.py
@@ -1,4 +1,3 @@
-import asyncio
 from abc import ABCMeta, abstractmethod
 from typing import (
   Dict,
@@ -10,6 +9,9 @@
   Union,
 )

+import anyio
+
+from pylabrobot.concurrency import AsyncExitStackWithShielding
 from pylabrobot.io.usb import USB
 from pylabrobot.liquid_handling.backends.backend import (
   LiquidHandlerBackend,
@@ -165,12 +167,9 @@ async def send_command(
     resp = await self.io.read(timeout=read_timeout)
     return self.parse_response(resp)

-  async def setup(self):
-    await super().setup()
-    await self.io.setup()
-
-  async def stop(self):
-    await self.io.stop()
+  async def _enter_lifespan(self, stack: AsyncExitStackWithShielding):
+    await super()._enter_lifespan(stack)
+    await stack.enter_async_context(self.io)


 class EVOBackend(TecanLiquidHandler):
@@ -261,13 +260,8 @@ def mca_connected(self) -> bool:
   def serialize(self) -> dict:
     return {**super().serialize(), **self.io.serialize()}

-  async def setup(self):
-    """Setup
-
-    Creates a USB connection and finds read/write interfaces.
-    """
-
-    await super().setup()
+  async def _enter_lifespan(self, stack: AsyncExitStackWithShielding):
+    await super()._enter_lifespan(stack)

     self._liha_connected = await self.setup_arm(EVOBackend.LIHA)
     self._mca_connected = await self.setup_arm(EVOBackend.MCA)
@@ -333,19 +327,19 @@ async def _park_mca(self):
     # Ensure MCA is initialized before moving
     await self.send_command(EVO.MCA, command="PIA")
-    await asyncio.sleep(0.5)
+    await anyio.sleep(0.5)

     # Raise MCA Z-axis first to avoid collision
     await self.send_command(EVO.MCA, command="PAA", params=[None, None, 2000])  # Raise Z-axis
-    await asyncio.sleep(1)
+    await anyio.sleep(1)

     # Move MCA to parking position (adjust X, Y as needed)
     await self.send_command(EVO.MCA, command="PAA", params=[6000, 1000, None])
-    await asyncio.sleep(1)
+    await anyio.sleep(1)

     # Stop movement to prevent drifting
     await self.send_command(EVO.MCA, command="BMA", params=[0, 0, 0])
-    await asyncio.sleep(0.5)
+    await anyio.sleep(0.5)

   # ============== LiquidHandlerBackend methods ==============

diff --git a/pylabrobot/liquid_handling/liquid_handler.py b/pylabrobot/liquid_handling/liquid_handler.py
index 26d8385b055..b53ac22d023 100644
--- a/pylabrobot/liquid_handling/liquid_handler.py
+++ b/pylabrobot/liquid_handling/liquid_handler.py
@@ -24,6 +24,7 @@
   cast,
 )
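The serializing backend above shows the recurring migration pattern: setup work stays in `_enter_lifespan`, and what used to live in `stop` becomes a callback registered on the stack. A stdlib sketch of the shape (PLR's `AsyncExitStackWithShielding` additionally shields the callback from cancellation; plain `contextlib.AsyncExitStack`, used here, does not, and `SerializingBackend` is a toy class, not the real one):

```python
import asyncio
import contextlib


class SerializingBackend:
  """Send "setup" on enter, register an async callback that sends "stop"
  on exit. Enter and exit can no longer be called out of order."""

  def __init__(self):
    self.sent = []

  async def send_command(self, command):
    self.sent.append(command)

  async def _enter_lifespan(self, stack):
    await self.send_command("setup")

    async def cleanup():
      await self.send_command("stop")

    # closest stdlib analogue of push_shielded_async_callback
    stack.push_async_callback(cleanup)


async def main():
  backend = SerializingBackend()
  async with contextlib.AsyncExitStack() as stack:
    await backend._enter_lifespan(stack)
    # backend is usable here
  return backend.sent  # "stop" was sent when the stack closed


assert asyncio.run(main()) == ["setup", "stop"]
```

Registering cleanup at the point where the corresponding setup succeeds (rather than in a distant `stop` method) means a failure halfway through setup only unwinds the parts that actually came up.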
+from pylabrobot.concurrency import AsyncExitStackWithShielding
 from pylabrobot.liquid_handling.channel_positioning import (
   compute_channel_offsets,
 )
@@ -155,15 +156,13 @@ def _resource_pickup(self) -> Optional[ResourcePickup]:
   def _resource_pickup(self, value: Optional[ResourcePickup]) -> None:
     self._resource_pickups[0] = value

-  async def setup(self, **backend_kwargs):
+  async def _enter_lifespan(self, stack: AsyncExitStackWithShielding):
     """Prepare the robot for use."""

-    if self.setup_finished:
-      raise RuntimeError("The setup has already finished. See `LiquidHandler.stop`.")
-
     self.backend.set_deck(self.deck)
     self.backend.set_heads(head=self.head, head96=self.head96)
-    await super().setup(**backend_kwargs)
+
+    await super()._enter_lifespan(stack)

     self.head = {c: TipTracker(thing=f"Channel {c}") for c in range(self.backend.num_channels)}
diff --git a/pylabrobot/liquid_handling/liquid_handler_tests.py b/pylabrobot/liquid_handling/liquid_handler_tests.py
index d27eba719e2..8c06bdfaf55 100644
--- a/pylabrobot/liquid_handling/liquid_handler_tests.py
+++ b/pylabrobot/liquid_handling/liquid_handler_tests.py
@@ -50,6 +50,7 @@
 )
 from pylabrobot.resources.well import Well
 from pylabrobot.serializer import serialize
+from pylabrobot.testing.concurrency import AnyioTestBase

 from .liquid_handler import LiquidHandler
 from .standard import (
@@ -112,11 +113,13 @@ def _make_disp(
 )


-class TestLiquidHandlerLayout(unittest.IsolatedAsyncioTestCase):
-  def setUp(self):
+class TestLiquidHandlerLayout(AnyioTestBase):
+  async def _enter_lifespan(self, stack):
+    await super()._enter_lifespan(stack)
     self.backend = _create_mock_backend(num_channels=8)
     self.deck = STARLetDeck()
     self.lh = LiquidHandler(self.backend, deck=self.deck)
+    await stack.enter_async_context(self.lh)

   def test_resource_assignment(self):
     tip_car = TIP_CAR_480_A00(name="tip_carrier")
@@ -476,8 +479,9 @@ def test_serialize(self):
 )


-class TestLiquidHandlerCommands(unittest.IsolatedAsyncioTestCase):
-  async def asyncSetUp(self):
+class TestLiquidHandlerCommands(AnyioTestBase):
+  async def _enter_lifespan(self, stack):
+    await super()._enter_lifespan(stack)
     self.maxDiff = None

     self.backend = _create_mock_backend(num_channels=8)
@@ -488,7 +492,7 @@ async def asyncSetUp(self):
     self.plate = Cor_96_wellplate_360ul_Fb(name="plate")
     self.deck.assign_child_resource(self.tip_rack, location=Coordinate(0, 0, 0))
     self.deck.assign_child_resource(self.plate, location=Coordinate(100, 100, 0))
-    await self.lh.setup()
+    await stack.enter_async_context(self.lh)

   async def test_offsets_tips(self):
     tip_spot = self.tip_rack.get_item("A1")
@@ -914,19 +918,18 @@ async def custom_pick_up_tips(ops, use_channels, non_default, default=True):
     self.backend = _create_mock_backend(num_channels=16)
     self.backend.pick_up_tips = custom_pick_up_tips
     self.lh = LiquidHandler(self.backend, deck=self.deck)
-    await self.lh.setup()
-
-    with no_tip_tracking():
-      set_strictness(Strictness.IGNORE)
-      await self.lh.pick_up_tips(self.tip_rack["A1"], non_default=True)
-      await self.lh.pick_up_tips(
-        self.tip_rack["A1"],
-        use_channels=[1],
-        non_default=True,
-        does_not_exist=True,
-      )
-      with self.assertRaises(TypeError):  # missing non_default
-        await self.lh.pick_up_tips(self.tip_rack["A1"], use_channels=[2])
+    async with self.lh:
+      with no_tip_tracking():
+        set_strictness(Strictness.IGNORE)
+        await self.lh.pick_up_tips(self.tip_rack["A1"], non_default=True)
+        await self.lh.pick_up_tips(
+          self.tip_rack["A1"],
+          use_channels=[1],
+          non_default=True,
+          does_not_exist=True,
+        )
+        with self.assertRaises(TypeError):  # missing non_default
+          await self.lh.pick_up_tips(self.tip_rack["A1"], use_channels=[2])

       set_strictness(Strictness.WARN)
       await self.lh.pick_up_tips(self.tip_rack["A1"], non_default=True, use_channels=[3])
@@ -1009,8 +1012,9 @@ async def test_pick_up_tips96_incomplete_rack(self):
     set_tip_tracking(enabled=False)


-class TestLiquidHandlerVolumeTracking(unittest.IsolatedAsyncioTestCase):
-  async def asyncSetUp(self):
+class TestLiquidHandlerVolumeTracking(AnyioTestBase):
+  async def _enter_lifespan(self, stack):
+    await super()._enter_lifespan(stack)
     self.backend = _create_mock_backend(num_channels=8)
     self.deck = STARLetDeck()
     self.lh = LiquidHandler(backend=self.backend, deck=self.deck)
@@ -1020,11 +1024,10 @@ async def asyncSetUp(self):
     self.deck.assign_child_resource(self.plate, location=Coordinate(100, 100, 0))
     self.single_well_plate = nest_1_troughplate_195000uL_Vb(name="single_well_plate")
     self.deck.assign_child_resource(self.single_well_plate, location=Coordinate(300, 100, 0))
-    await self.lh.setup()
-    set_volume_tracking(enabled=True)

-  async def asyncTearDown(self):
-    set_volume_tracking(enabled=False)
+    await stack.enter_async_context(self.lh)
+    set_volume_tracking(enabled=True)
+    stack.callback(set_volume_tracking, enabled=False)

   async def test_dispense_with_volume_tracking(self):
     well = self.plate.get_item("A1")
@@ -1116,10 +1119,11 @@ async def test_96_head_volume_tracking_well_list(self):
     await self.lh.return_tips96()


-class TestLiquidHandlerSerializeState(unittest.IsolatedAsyncioTestCase):
+class TestLiquidHandlerSerializeState(AnyioTestBase):
   """Tests for LiquidHandler.serialize_state() and load_state()."""

-  async def asyncSetUp(self):
+  async def _enter_lifespan(self, stack):
+    await super()._enter_lifespan(stack)
     self.backend = _create_mock_backend(num_channels=8)
     self.deck = STARLetDeck()
     self.lh = LiquidHandler(backend=self.backend, deck=self.deck)
@@ -1127,7 +1131,7 @@ async def asyncSetUp(self):
     self.plate = Cor_96_wellplate_360ul_Fb(name="plate")
     self.deck.assign_child_resource(self.tip_rack, location=Coordinate(0, 0, 0))
     self.deck.assign_child_resource(self.plate, location=Coordinate(100, 100, 0))
-    await self.lh.setup()
+    await stack.enter_async_context(self.lh)

   async def test_serialize_state_after_setup(self):
     state = self.lh.serialize_state()
@@ -1151,20 +1155,18 @@ async def test_serialize_state_no_head96(self):
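Throughout these tests, `await self.lh.setup()` becomes `await stack.enter_async_context(self.lh)` or an explicit `async with self.lh:` block. The practical payoff is that teardown now runs even when the body raises, which the old paired-calls style did not guarantee. A minimal sketch with a toy frontend (`FakeLiquidHandler` is illustrative; the real `LiquidHandler` gets its `__aenter__`/`__aexit__` from `AsyncResource`):

```python
import asyncio


class FakeLiquidHandler:
  """Toy frontend whose teardown must run even if the body raises."""

  def __init__(self):
    self.log = []

  async def __aenter__(self):
    self.log.append("setup")
    return self

  async def __aexit__(self, exc_type, exc, tb):
    self.log.append("stop")
    return False  # don't swallow the exception


async def run():
  lh = FakeLiquidHandler()
  try:
    async with lh:
      lh.log.append("aspirate")
      raise RuntimeError("tip crash")  # simulated mid-protocol failure
  except RuntimeError:
    pass
  return lh.log


# teardown ran despite the exception
assert asyncio.run(run()) == ["setup", "aspirate", "stop"]
```

With separate `setup`/`stop` calls, the equivalent guarantee requires a `try`/`finally` at every call site; the context manager centralizes it.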
     type(backend).head96_installed = PropertyMock(return_value=False)
     deck = STARLetDeck()
     lh = LiquidHandler(backend=backend, deck=deck)
-    await lh.setup()
-
-    state = lh.serialize_state()
-    self.assertIsNone(state["head96_state"])
+    async with lh:
+      state = lh.serialize_state()
+      self.assertIsNone(state["head96_state"])

   async def test_serialize_state_no_arms(self):
     backend = _create_mock_backend(num_channels=8)
     type(backend).num_arms = PropertyMock(return_value=0)
     deck = STARLetDeck()
     lh = LiquidHandler(backend=backend, deck=deck)
-    await lh.setup()
-
-    state = lh.serialize_state()
-    self.assertIsNone(state["arm_state"])
+    async with lh:
+      state = lh.serialize_state()
+      self.assertIsNone(state["arm_state"])

   async def test_serialize_state_with_resource_pickup(self):
     resource = self.plate
@@ -1212,10 +1214,11 @@ async def test_load_state_backward_compatible(self):
     self.lh.load_state(old_state)  # should not raise


-class TestNoGoZoneIntegration(unittest.IsolatedAsyncioTestCase):
+class TestNoGoZoneIntegration(AnyioTestBase):
   """Integration tests for no-go zone channel distribution through LiquidHandler."""

-  async def asyncSetUp(self):
+  async def _enter_lifespan(self, stack):
+    await super()._enter_lifespan(stack)
     self.backend = _create_mock_backend(num_channels=8)
     self.backend.get_channel_spacings.side_effect = lambda channels: [9.0] * len(channels)
     self.deck = STARLetDeck()
@@ -1236,7 +1239,7 @@ async def asyncSetUp(self):
     self.tip_rack = hamilton_96_tiprack_300uL_filter(name="tip_rack")
     self.deck.assign_child_resource(self.tip_rack, location=Coordinate(0, 0, 0))

-    await self.lh.setup()
+    await stack.enter_async_context(self.lh)

   async def test_aspirate_2_channels_avoids_no_go_zone(self):
     """2 channels on a trough with a center divider should be placed in separate compartments."""
diff --git a/pylabrobot/machines/backend.py b/pylabrobot/machines/backend.py
index 1dab93ce246..1b23648911d 100644
--- a/pylabrobot/machines/backend.py
+++ b/pylabrobot/machines/backend.py
@@ -1,26 +1,37 @@
+import contextlib
 import inspect
 import weakref
-from abc import ABC, abstractmethod
+from typing import Optional

+from pylabrobot.concurrency import AsyncExitStackWithShielding, AsyncResource, global_manager
 from pylabrobot.serializer import SerializableMixin
 from pylabrobot.utils.object_parsing import find_subclass


-class MachineBackend(SerializableMixin, ABC):
+class MachineBackend(SerializableMixin, AsyncResource):
   """Abstract class for machine backends."""

   _instances: weakref.WeakSet["MachineBackend"] = weakref.WeakSet()

   def __init__(self):
     self._instances.add(self)
+    self._stack: Optional[contextlib.AsyncExitStack] = None

-  @abstractmethod
-  async def setup(self):
+  def __init_subclass__(cls, **kwargs):
+    super().__init_subclass__(**kwargs)
+    if "setup" in cls.__dict__:
+      raise TypeError(f"Subclass {cls.__name__} is not allowed to override 'setup'")
+    if "stop" in cls.__dict__:
+      raise TypeError(f"Subclass {cls.__name__} is not allowed to override 'stop'")
+
+  async def _enter_lifespan(self, stack: AsyncExitStackWithShielding):
     pass

-  @abstractmethod
+  async def setup(self):
+    await global_manager.manage_context(self)
+
   async def stop(self):
-    pass
+    await global_manager.release_context(self)

   def serialize(self) -> dict:
     return {"type": self.__class__.__name__}
diff --git a/pylabrobot/machines/machine.py b/pylabrobot/machines/machine.py
index d241d8809a1..933ca2e94f2 100644
--- a/pylabrobot/machines/machine.py
+++ b/pylabrobot/machines/machine.py
@@ -2,9 +2,9 @@
 import functools
 import sys
-from abc import ABC
 from typing import Any, Awaitable, Callable, TypeVar

+from pylabrobot.concurrency import AsyncExitStackWithShielding, AsyncResource, global_manager
 from pylabrobot.machines.backend import MachineBackend
 from pylabrobot.serializer import SerializableMixin
@@ -38,16 +38,28 @@ async def wrapper(*args, **kwargs):
   return wrapper


-class Machine(SerializableMixin, ABC):
+class Machine(SerializableMixin, AsyncResource):
   """Abstract base class for machine frontends."""

   def __init__(self, backend: MachineBackend):
     self.backend = backend
-    self._setup_finished = False
+
+  def __init_subclass__(cls, **kwargs):
+    super().__init_subclass__(**kwargs)
+    if "setup" in cls.__dict__:
+      raise TypeError(
+        f"Class {cls.__name__} overrides `setup`. "
+        "Use `_enter_lifespan` instead for structured concurrency."
+      )
+    if "stop" in cls.__dict__:
+      raise TypeError(
+        f"Class {cls.__name__} overrides `stop`. "
+        "Use `_enter_lifespan` instead for structured concurrency."
+      )

   @property
   def setup_finished(self) -> bool:
-    return self._setup_finished
+    return getattr(self, "_active_lifespan", None) is not None

   def serialize(self) -> dict:
     return {"backend": self.backend.serialize()}
@@ -60,18 +72,17 @@ def deserialize(cls, data: dict):
     data_copy["backend"] = backend
     return cls(**data_copy)

-  async def setup(self, **backend_kwargs):
-    await self.backend.setup(**backend_kwargs)
-    self._setup_finished = True
-
-  @need_setup_finished
-  async def stop(self):
-    await self.backend.stop()
-    self._setup_finished = False
+  async def _enter_lifespan(self, stack: AsyncExitStackWithShielding):
+    await stack.enter_async_context(self.backend)

-  async def __aenter__(self):
-    await self.setup()
-    return self
+  async def setup(self, **kwargs):
+    if kwargs:
+      # TODO: Design question: Do we need kwargs? We could elevate
+      # `_lifespan` to a public API `lifespan`, taking kwargs. However, having
+      # both `lifespan` as well as `__aenter__`/`__aexit__` goes against the
+      # python ZEN "There should be one, and preferably only one obvious way to do it".
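The `__init_subclass__` hooks added to `Machine` and `MachineBackend` turn an accidental `setup`/`stop` override into an error at class definition time, before any instance exists. A self-contained sketch of just that guard (the `Machine` class here is a stripped-down stand-in, not the real one):

```python
class Machine:
  """Sketch of the subclass guard: overriding `setup`/`stop` fails when the
  class body is executed, pointing authors at `_enter_lifespan` instead."""

  def __init_subclass__(cls, **kwargs):
    super().__init_subclass__(**kwargs)
    for name in ("setup", "stop"):
      if name in cls.__dict__:
        raise TypeError(
          f"Class {cls.__name__} overrides `{name}`. "
          "Use `_enter_lifespan` instead for structured concurrency."
        )


class GoodMachine(Machine):
  async def _enter_lifespan(self, stack):  # the sanctioned hook
    pass


try:
  class BadMachine(Machine):  # rejected before it can ever be instantiated
    async def setup(self):
      pass
except TypeError as e:
  error = str(e)
```

Checking `cls.__dict__` (not `hasattr`) is what makes this work: it only flags methods defined directly on the subclass, so the base class's own legacy `setup`/`stop` shims are unaffected.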
+      raise ValueError("Keyword arguments during setup are not allowed anymore")
+    await global_manager.manage_context(self)

-  async def __aexit__(self, exc_type, exc_value, traceback):
-    await self.stop()
+  async def stop(self):
+    await global_manager.release_context(self)
diff --git a/pylabrobot/machines/machine_tests.py b/pylabrobot/machines/machine_tests.py
index bdfbb2d506e..c14a477d8e2 100644
--- a/pylabrobot/machines/machine_tests.py
+++ b/pylabrobot/machines/machine_tests.py
@@ -7,14 +7,9 @@ class TestMachine(unittest.TestCase):
   class MockBackend(MachineBackend):
     def __init__(self, mock_param):
+      super().__init__()
       self.mock_param = mock_param

-    async def setup(self):
-      pass
-
-    async def stop(self):
-      pass
-
     def serialize(self):
       return {**super().serialize(), "mock_param": self.mock_param}
diff --git a/pylabrobot/microscopes/molecular_devices/pico/backend.py b/pylabrobot/microscopes/molecular_devices/pico/backend.py
index 9b5cfacdc5b..65cac0938d0 100644
--- a/pylabrobot/microscopes/molecular_devices/pico/backend.py
+++ b/pylabrobot/microscopes/molecular_devices/pico/backend.py
@@ -1,4 +1,3 @@
-import asyncio
 import base64
 import hashlib
 import io
@@ -9,6 +8,9 @@
 from collections import defaultdict
 from typing import Callable, Dict, List, Optional, Tuple, TypeVar

+import anyio
+
+from pylabrobot.concurrency import AsyncExitStackWithShielding
 from pylabrobot.io.sila.grpc import (
   command_execution_uuid,
   decode_command_confirmation,
@@ -401,7 +403,7 @@ async def _rpc(
   ) -> _T:
     for attempt in range(2):
       try:
-        return await asyncio.to_thread(fn)
+        return await anyio.to_thread.run_sync(fn)
       except grpc.RpcError as e:
         if attempt == 0 and with_lock and "CommandRequiresLock" in decode_grpc_error(e):
           await self._relock()
@@ -476,20 +478,37 @@ async def _get_installed_filter_cubes(self) -> List[dict]:

   # -- lifecycle --

-  async def setup(self) -> None:
+  async def _enter_lifespan(self, stack: AsyncExitStackWithShielding) -> None:
+    await super()._enter_lifespan(stack)
+    if not HAS_GRPC:
       raise RuntimeError(
         f"grpcio is required for the PicoBackend. Import error: {_GRPC_IMPORT_ERROR}"
       )
-    self._channel = grpc.insecure_channel(
-      f"{self._host}:{self._port}",
-      options=[
-        ("grpc.keepalive_time_ms", 10000),
-        ("grpc.max_receive_message_length", 64 * 1024 * 1024),
-      ],
+    # TODO: We really shouldn't use the sync API here, even if we use thread-hopping.
+    # There is in fact grpc.aio, which would be a lot cleaner.
+    self._channel = stack.enter_context(
+      grpc.insecure_channel(
+        f"{self._host}:{self._port}",
+        options=[
+          ("grpc.keepalive_time_ms", 10000),
+          ("grpc.max_receive_message_length", 64 * 1024 * 1024),
+        ],
+      )
     )
     self._lock_id = "pylabrobot"

+    async def cleanup():
+      if self._locked:
+        try:
+          await self._unlock()
+        except (grpc.RpcError, RuntimeError) as e:
+          logger.warning("PicoBackend: unlock failed during stop: %s", e)
+      self._channel = None
+      logger.info("PicoBackend: stopped")
+
+    stack.push_shielded_async_callback(cleanup)
+
+    # Try to unlock a stale lock from a previous session that didn't clean up.
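The legacy `setup`/`stop` shims above delegate to a `global_manager` that keeps an exit stack alive between the two calls. A minimal stdlib sketch of that bridge, to make its cost concrete (the real `global_manager` is PLR-specific and likely more involved; `GlobalManager`, `manage_context`, and `release_context` here mirror the names in the diff but are toy implementations):

```python
import asyncio
import contextlib


class Resource:
  """Toy AsyncResource."""

  def __init__(self):
    self.active = False

  async def __aenter__(self):
    self.active = True
    return self

  async def __aexit__(self, exc_type, exc, tb):
    self.active = False


class GlobalManager:
  """`manage_context` enters a resource's async context and parks the exit
  stack; `release_context` closes it later. The scope of the async work is
  no longer tied to any `async with` block, which is exactly what makes
  this API unstructured and why the docs discourage it outside
  interactive use."""

  def __init__(self):
    self._stacks = {}

  async def manage_context(self, resource):
    stack = contextlib.AsyncExitStack()
    await stack.enter_async_context(resource)
    self._stacks[id(resource)] = stack

  async def release_context(self, resource):
    await self._stacks.pop(id(resource)).aclose()


async def demo():
  manager = GlobalManager()
  res = Resource()
  await manager.manage_context(res)   # legacy `setup`
  active_in_between = res.active
  await manager.release_context(res)  # legacy `stop`
  return active_in_between, res.active


assert asyncio.run(demo()) == (True, False)
```

Note what is lost: if the script crashes between the two calls, nothing closes the parked stack, and exceptions raised during cleanup surface far from the code that triggered them.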
     try:
       await self._unlock()
@@ -518,17 +537,6 @@ async def setup(self) -> None:

     logger.info("PicoBackend: connected to %s:%d", self._host, self._port)

-  async def stop(self) -> None:
-    if self._channel is not None:
-      if self._locked:
-        try:
-          await self._unlock()
-        except (grpc.RpcError, RuntimeError) as e:
-          logger.warning("PicoBackend: unlock failed during stop: %s", e)
-      self._channel.close()
-      self._channel = None
-    logger.info("PicoBackend: stopped")
-
   # -- configuration --

   async def get_configuration(self) -> dict:
diff --git a/pylabrobot/microscopes/molecular_devices/pico/backend_tests.py b/pylabrobot/microscopes/molecular_devices/pico/backend_tests.py
index 8e8c123ea65..6eec7986092 100644
--- a/pylabrobot/microscopes/molecular_devices/pico/backend_tests.py
+++ b/pylabrobot/microscopes/molecular_devices/pico/backend_tests.py
@@ -150,6 +150,12 @@ def __init__(self):
   def close(self):
     self.closed = True

+  def __enter__(self):
+    return self
+
+  def __exit__(self, exc_type, exc_val, exc_tb):
+    self.close()
+
   def set_response(self, path: str, response: bytes):
     self.responses[path] = response

@@ -238,13 +244,15 @@ async def test_setup_sends_correct_sequence(self):
     )

     with patch("grpc.insecure_channel", return_value=channel):
-      await backend.setup()
+      async with backend:
+        self.assertEqual(len(channel.calls), 4)
+        self.assertEqual(channel.calls[0].path, f"/{_LOCK_SVC}/UnlockServer")
+        self.assertEqual(channel.calls[1].path, f"/{_LOCK_SVC}/LockServer")
+        self.assertEqual(channel.calls[2].path, f"/{_OBJ_SVC}/Get_InstalledObjectives")
+        self.assertEqual(channel.calls[3].path, f"/{_FC_SVC}/Get_InstalledFilterCubes")

-    self.assertEqual(len(channel.calls), 4)
-    self.assertEqual(channel.calls[0].path, f"/{_LOCK_SVC}/UnlockServer")
-    self.assertEqual(channel.calls[1].path, f"/{_LOCK_SVC}/LockServer")
-    self.assertEqual(channel.calls[2].path, f"/{_OBJ_SVC}/Get_InstalledObjectives")
-    self.assertEqual(channel.calls[3].path, f"/{_FC_SVC}/Get_InstalledFilterCubes")
+    self.assertEqual(len(channel.calls), 5)
+    self.assertEqual(channel.calls[4].path, f"/{_LOCK_SVC}/UnlockServer")

     # Unlock request contains lock ID
     self.assertEqual(_decode_sila_string_from_request(channel.calls[0].request), "pylabrobot")
@@ -284,7 +292,8 @@ async def test_setup_configures_objectives_and_filter_cubes(self):
     channel.set_response(f"/{_FC_SVC}/ChangeHardware", b"")

     with patch("grpc.insecure_channel", return_value=channel):
-      await backend.setup()
+      async with backend:
+        pass

     # Verify ChangeHardware was called with correct JSON params
     obj_change_calls = channel.get_calls(f"/{_OBJ_SVC}/ChangeHardware")
@@ -307,13 +316,33 @@ async def test_setup_configures_objectives_and_filter_cubes(self):

 class TestStop(unittest.IsolatedAsyncioTestCase):
   async def test_stop_sends_unlock(self):
-    backend, channel = _make_backend()
+    backend = ExperimentalPicoBackend(host="127.0.0.1")
+    channel = _MockChannel()

-    await backend.stop()
+    channel.set_response(f"/{_LOCK_SVC}/UnlockServer", b"")
+    channel.set_response(f"/{_LOCK_SVC}/LockServer", b"")
+    channel.set_response(
+      f"/{_OBJ_SVC}/Get_InstalledObjectives",
+      _sila_string_response(json.dumps({"objectivesData": []})),
+    )
+    channel.set_response(
+      f"/{_FC_SVC}/Get_InstalledFilterCubes",
+      _sila_string_response(json.dumps({"filterCubesData": []})),
+    )

-    self.assertEqual(len(channel.calls), 1)
-    self.assertEqual(channel.calls[0].path, f"/{_LOCK_SVC}/UnlockServer")
-    self.assertEqual(_decode_sila_string_from_request(channel.calls[0].request), "pylabrobot")
+    with patch("grpc.insecure_channel", return_value=channel):
+      async with backend:
+        pass
+
+    # Expected calls:
+    # 0. UnlockServer (stale)
+    # 1. LockServer
+    # 2. Get_InstalledObjectives
+    # 3. Get_InstalledFilterCubes
+    # 4. UnlockServer (from cleanup!)
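The Pico backend shows a detail worth noting: synchronous context managers (here the grpc channel, closed via the `__enter__`/`__exit__` pair added to the mock) go on the same stack as the async cleanups, via `enter_context` instead of `enter_async_context`. A stdlib sketch (`FakeChannel` and `Backend` are illustrative stand-ins, not the real classes):

```python
import asyncio
import contextlib


class FakeChannel:
  """Stand-in for grpc.insecure_channel: a *sync* context manager."""

  def __init__(self):
    self.closed = False

  def __enter__(self):
    return self

  def __exit__(self, exc_type, exc, tb):
    self.closed = True


class Backend:
  async def _enter_lifespan(self, stack):
    # Sync context managers join the same stack via enter_context; they are
    # closed in LIFO order, interleaved with any async cleanups.
    self._channel = stack.enter_context(FakeChannel())


async def demo():
  backend = Backend()
  async with contextlib.AsyncExitStack() as stack:
    await backend._enter_lifespan(stack)
    assert not backend._channel.closed  # open while the stack is open
  return backend._channel.closed


assert asyncio.run(demo()) is True
```

This removes the `if self._channel is not None: self._channel.close()` bookkeeping the old `stop` method needed: the stack only runs cleanups for things that were actually entered.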
+    self.assertEqual(len(channel.calls), 5)
+    self.assertEqual(channel.calls[4].path, f"/{_LOCK_SVC}/UnlockServer")
+    self.assertEqual(_decode_sila_string_from_request(channel.calls[4].request), "pylabrobot")

     self.assertTrue(channel.closed)

@@ -825,7 +854,3 @@ def test_roundtrip(self):
     self.assertEqual(meta["blob_checksum"], 42)
     self.assertEqual(meta["packet_count"], 5)
     self.assertEqual(meta["packet_index"], 2)
-
-
-if __name__ == "__main__":
-  unittest.main()
diff --git a/pylabrobot/only_fans/backend.py b/pylabrobot/only_fans/backend.py
index 7be5c212785..e1893f06632 100644
--- a/pylabrobot/only_fans/backend.py
+++ b/pylabrobot/only_fans/backend.py
@@ -1,5 +1,6 @@
 from abc import ABCMeta, abstractmethod

+from pylabrobot.concurrency import AsyncExitStackWithShielding
 from pylabrobot.machines.backend import MachineBackend


@@ -7,7 +8,7 @@ class FanBackend(MachineBackend, metaclass=ABCMeta):
   """Abstract base class for fan backends."""

   @abstractmethod
-  async def setup(self) -> None:
+  async def _enter_lifespan(self, stack: AsyncExitStackWithShielding) -> None:
     """Set up the fan. This should be called before any other methods."""

   @abstractmethod
@@ -17,7 +18,3 @@ async def turn_on(self, intensity: int) -> None:

   @abstractmethod
   async def turn_off(self) -> None:
     """Stop the fan, but don't close the connection."""
-
-  @abstractmethod
-  async def stop(self) -> None:
-    """Close all connections to the fan and make sure setup() can be called again."""
diff --git a/pylabrobot/only_fans/chatterbox.py b/pylabrobot/only_fans/chatterbox.py
index 25104bb4822..50dc2118873 100644
--- a/pylabrobot/only_fans/chatterbox.py
+++ b/pylabrobot/only_fans/chatterbox.py
@@ -1,17 +1,20 @@
+from pylabrobot.concurrency import AsyncExitStackWithShielding
 from pylabrobot.only_fans import FanBackend


 class FanChatterboxBackend(FanBackend):
   """Chatter box backend for device-free testing. Prints out all operations."""

-  async def setup(self) -> None:
+  async def _enter_lifespan(self, stack: AsyncExitStackWithShielding) -> None:
     print("Setting up the fan.")

+    def cleanup():
+      print("Stopping the fan.")
+
+    stack.callback(cleanup)
+
   async def turn_on(self, intensity: int) -> None:
     print(f"Turning on the fan at intensity {intensity}.")

   async def turn_off(self) -> None:
     print("Turning off the fan.")
-
-  async def stop(self) -> None:
-    print("Stopping the fan.")
diff --git a/pylabrobot/only_fans/fan.py b/pylabrobot/only_fans/fan.py
index c34e784a9ee..d45b3680a0e 100644
--- a/pylabrobot/only_fans/fan.py
+++ b/pylabrobot/only_fans/fan.py
@@ -1,5 +1,6 @@
-import asyncio
+import anyio

+from pylabrobot.concurrency import AsyncExitStackWithShielding
 from pylabrobot.machines.machine import Machine

 from .backend import FanBackend
@@ -14,9 +15,13 @@ def __init__(self, backend: FanBackend):
     super().__init__(backend=backend)
     self.backend: FanBackend = backend  # fix type

-  async def stop(self):
-    await self.backend.turn_off()
-    await super().stop()
+  async def _enter_lifespan(self, stack: AsyncExitStackWithShielding) -> None:
+    await super()._enter_lifespan(stack)
+
+    async def cleanup():
+      await self.backend.turn_off()
+
+    stack.push_shielded_async_callback(cleanup)

   async def turn_on(self, intensity: int, duration=None):
     """Run the fan
@@ -29,7 +34,7 @@ async def turn_on(self, intensity: int, duration=None):
     await self.backend.turn_on(intensity=intensity)

     if duration is not None:
-      await asyncio.sleep(duration)
+      await anyio.sleep(duration)
       await self.backend.turn_off()

   async def turn_off(self):
diff --git a/pylabrobot/only_fans/hamilton_hepa_fan_backend.py b/pylabrobot/only_fans/hamilton_hepa_fan_backend.py
index 0d62c6143f2..e91f84160ec 100644
--- a/pylabrobot/only_fans/hamilton_hepa_fan_backend.py
+++ b/pylabrobot/only_fans/hamilton_hepa_fan_backend.py
@@ -1,5 +1,6 @@
-import asyncio
+import anyio

+from pylabrobot.concurrency import AsyncExitStackWithShielding
 from pylabrobot.io.ftdi import FTDI

 from .backend import FanBackend
@@ -13,8 +14,9 @@ def __init__(self, device_id=None):
       human_readable_device_name="Hamilton HEPA Fan", device_id=device_id, vid=0x0856, pid=0xAC11
     )

-  async def setup(self):
-    await self.io.setup()
+  async def _enter_lifespan(self, stack: AsyncExitStackWithShielding):
+    await super()._enter_lifespan(stack)  # type: ignore[safe-super]
+    await stack.enter_async_context(self.io)
     await self.io.set_baudrate(9600)
     await self.io.set_line_property(8, 0, 0)  # 8N1
     await self.io.set_latency_timer(16)
@@ -142,12 +144,9 @@ async def turn_on(self, intensity):  # Speed is an integer percent between 0 and
   async def turn_off(self):
     await self.send(b"\x55\xc1\x01\x11\x00\x7b")

-  async def stop(self):
-    await self.io.stop()
-
   async def send(self, command: bytes):
     await self.io.write(command)
-    await asyncio.sleep(0.1)
+    await anyio.sleep(0.1)
     await self.io.read(64)
diff --git a/pylabrobot/peeling/xpeel_backend.py b/pylabrobot/peeling/xpeel_backend.py
index 8ea1b4a6983..6d8ab2d06ba 100644
--- a/pylabrobot/peeling/xpeel_backend.py
+++ b/pylabrobot/peeling/xpeel_backend.py
@@ -1,8 +1,9 @@
 import logging
-import time
 from dataclasses import dataclass
 from typing import List, Literal, Tuple

+import anyio
+
 try:
   import serial  # type: ignore
@@ -11,6 +12,7 @@
   HAS_SERIAL = False
   _SERIAL_IMPORT_ERROR = e

+from pylabrobot.concurrency import AsyncExitStackWithShielding
 from pylabrobot.io.serial import Serial
 from pylabrobot.peeling.backend import PeelerBackend

@@ -71,12 +73,14 @@ def __init__(self, port: str, logger=None, timeout=None):
       rtscts=False,
     )

-  async def setup(self):
-    await self.io.setup()
+  async def _enter_lifespan(self, stack: AsyncExitStackWithShielding) -> None:
+    await super()._enter_lifespan(stack)
+    await stack.enter_async_context(self.io)
+
+    def cleanup():
+      self.logger.info("Serial interface closed.")

-  async def stop(self):
-    await self.io.stop()
-    self.logger.info("Serial interface closed.")
+    stack.callback(cleanup)

   @classmethod
   def describe_error(cls, code: int) -> str:
@@ -123,36 +127,36 @@ async def _send_command(
     await self.io.write(full_cmd.encode("ascii"))

     responses: List[str] = []
-    start = time.time()
-    while time.time() - start < self.response_timeout:
-      raw = await self.io.readline()
-      line = raw.decode("ascii", errors="ignore").strip()
-      if not line:
-        continue
-
-      display_line = line
-      if line.startswith("*ready:"):
-        parsed = self.parse_ready_line(line)
-        if parsed:
-          code, desc = parsed
-          display_line = f"{line} [{desc}]"
-
-      responses.append(display_line)
-      self.logger.info(f"Received: {display_line}")
-      print(f"Received: {display_line}")
-
-      if line.startswith("*ack"):
-        if not wait_for_ready:
+    with anyio.move_on_after(self.response_timeout) as scope:
+      while True:
+        raw = await self.io.readline()
+        line = raw.decode("ascii", errors="ignore").strip()
+        if not line:
+          continue
+
+        display_line = line
+        if line.startswith("*ready:"):
+          parsed = self.parse_ready_line(line)
+          if parsed:
+            code, desc = parsed
+            display_line = f"{line} [{desc}]"
+
+        responses.append(display_line)
+        self.logger.info(f"Received: {display_line}")
+        print(f"Received: {display_line}")
+
+        if line.startswith("*ack"):
+          if not wait_for_ready:
+            break
+          continue
+
+        if wait_for_ready and line.startswith("*ready"):
           break
-        continue

-      if wait_for_ready and line.startswith("*ready"):
-        break
-
-      if not wait_for_ready and not expect_ack:
-        break
+        if not wait_for_ready and not expect_ack:
+          break

-    if time.time() - start >= self.response_timeout:
+    if scope.cancel_called:
       self.logger.warning(
         "Timed out waiting for response to %s after %.2fs",
         full_cmd.strip(),
diff --git a/pylabrobot/plate_reading/agilent/biotek_backend.py b/pylabrobot/plate_reading/agilent/biotek_backend.py
index d119779ba7c..f9733289ef7 100644
--- a/pylabrobot/plate_reading/agilent/biotek_backend.py
+++ b/pylabrobot/plate_reading/agilent/biotek_backend.py
@@ -1,9 +1,11 @@
-import asyncio
 import enum
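The XPeel rewrite replaces manual `time.time()` deadline arithmetic with `anyio.move_on_after`, where timing out is a normal, inspectable outcome (`scope.cancel_called`) rather than something each loop iteration must re-check. The closest stdlib equivalent uses `asyncio.wait_for` around the whole loop; a sketch of that shape (`read_lines` and the fake device generators are illustrative, not PLR API):

```python
import asyncio


async def read_lines(lines, timeout):
  """Collect device responses until a terminator or until the deadline.
  Timing out returns normally with a flag, mirroring `scope.cancel_called`
  in the anyio version; it is not an exception for the caller."""
  responses = []
  timed_out = False

  async def loop():
    async for line in lines:
      responses.append(line)
      if line.startswith("*ready"):
        break

  try:
    await asyncio.wait_for(loop(), timeout)
  except asyncio.TimeoutError:
    timed_out = True  # deadline hit mid-read; partial responses are kept
  return responses, timed_out


async def fast_device():
  for line in ["*ack", "*ready:0"]:
    yield line


async def slow_device():
  yield "*ack"
  await asyncio.sleep(10)  # never becomes ready in time
  yield "*ready:0"


ok = asyncio.run(read_lines(fast_device(), timeout=1.0))
late = asyncio.run(read_lines(slow_device(), timeout=0.05))
```

The deadline now covers the whole exchange, including time spent blocked inside `readline`, which the old per-iteration `time.time()` check could not interrupt.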
import logging import time from typing import Dict, Iterable, List, Optional, Tuple +import anyio + +from pylabrobot.concurrency import AsyncExitStackWithShielding from pylabrobot.io.ftdi import FTDI from pylabrobot.plate_reading.backend import PlateReaderBackend from pylabrobot.resources import Plate, Well @@ -78,10 +80,12 @@ def _non_overlapping_rectangles( rects.sort() return rects - async def setup(self) -> None: + async def _enter_lifespan(self, stack: AsyncExitStackWithShielding): + await super()._enter_lifespan(stack) logger.info(f"{self.__class__.__name__} setting up") - await self.io.setup() + await self.io._enter_lifespan(stack) + await self.io.usb_reset() await self.io.set_latency_timer(16) await self.io.set_baudrate(9600) # 0x38 0x41 @@ -98,14 +102,16 @@ async def setup(self) -> None: self._version = await self.get_firmware_version() self._shaking = False - self._shaking_task: Optional[asyncio.Task] = None + self._shake_cancel_scope: Optional[anyio.CancelScope] = None + self._tg = await stack.enter_async_context(anyio.create_task_group()) - async def stop(self) -> None: - logger.info(f"{self.__class__.__name__} stopping") - await self.stop_shaking() - await self.io.stop() + async def _cleanup(): + logger.info(f"{self.__class__.__name__} stopping") + await self.stop_shaking() + self._slow_mode = None + self._tg.cancel_scope.cancel() - self._slow_mode = None + stack.push_shielded_async_callback(_cleanup) @property def version(self) -> str: @@ -159,17 +165,17 @@ async def _read_until(self, terminator: bytes, timeout: Optional[float] = None) timeout = self.timeout x = None res = b"" - t0 = time.time() - while x != terminator: - x = await self.io.read(1) - res += x - - if time.time() - t0 > timeout: - logger.debug(f"{self.__class__.__name__} received incomplete %s", res) - raise TimeoutError(f"{self.__class__.__name__}: Timeout while waiting for response") + try: + with anyio.fail_after(timeout): + while x != terminator: + x = await self.io.read(1) + res += 
x - if x == b"": - await asyncio.sleep(0.01) + if x == b"": + await anyio.sleep(0.01) + except TimeoutError: + logger.debug(f"{self.__class__.__name__} received incomplete %s", res) + raise TimeoutError(f"{self.__class__.__name__}: Timeout while waiting for response") from None logger.debug(f"{self.__class__.__name__} received %s", res) return res @@ -545,9 +551,11 @@ async def shake(self, shake_type: ShakeType, frequency: int) -> None: Args: frequency: speed, in mm. 360 CPM = 6mm; 410 CPM = 5mm; 493 CPM = 4mm; 567 CPM = 3mm; 731 CPM = 2mm; 1096 CPM = 1mm """ + assert not self._shaking max_duration = 16 * 60 # 16 minutes - self._shaking_started = asyncio.Event() + shaking_started = anyio.Event() + self._shaking_stopped = anyio.Event() async def shake_maximal_duration(): """This method will start the shaking, but returns immediately after @@ -563,34 +571,28 @@ async def shake_maximal_duration(): resp = await self.send_command("O") assert resp == b"\x060000\x03" - if not self._shaking_started.is_set(): - self._shaking_started.set() + shaking_started.set() async def shake_continuous(): - while self._shaking: - await shake_maximal_duration() - - # short sleep allows = frequent checks for fast stopping - seconds_since_start: float = 0 - loop_wait_time = 0.25 - while seconds_since_start < max_duration and self._shaking: - seconds_since_start += loop_wait_time - await asyncio.sleep(loop_wait_time) + try: + with anyio.CancelScope() as scope: + self._shake_cancel_scope = scope + while True: + await shake_maximal_duration() + await anyio.sleep(max_duration) + finally: + with anyio.CancelScope(shield=True): + await self._abort() + self._shaking = False + self._shaking_stopped.set() + self._shake_cancel_scope = None self._shaking = True - self._shaking_task = asyncio.create_task(shake_continuous()) + self._tg.start_soon(shake_continuous) - await self._shaking_started.wait() + await shaking_started.wait() async def stop_shaking(self) -> None: - if self._shaking: - await 
self._abort() - self._shaking = False - if self._shaking_task is not None: - self._shaking_task.cancel() - try: - await self._shaking_task - except asyncio.CancelledError: - # Task cancellation is expected here; safe to ignore this exception. - pass - self._shaking_task = None + if self._shake_cancel_scope is not None: + self._shake_cancel_scope.cancel() + await self._shaking_stopped.wait() diff --git a/pylabrobot/plate_reading/agilent/biotek_cytation_backend.py b/pylabrobot/plate_reading/agilent/biotek_cytation_backend.py index 67d87b3f3b3..24f6dfb6b43 100644 --- a/pylabrobot/plate_reading/agilent/biotek_cytation_backend.py +++ b/pylabrobot/plate_reading/agilent/biotek_cytation_backend.py @@ -1,5 +1,4 @@ -import asyncio -import atexit +import contextlib import logging import math import re @@ -8,6 +7,9 @@ from dataclasses import dataclass from typing import List, Literal, Optional, Tuple, Union +import anyio + +from pylabrobot.concurrency import AsyncExitStackWithShielding from pylabrobot.plate_reading.agilent.biotek_backend import BioTekPlateReaderBackend from pylabrobot.plate_reading.backend import ImagerBackend from pylabrobot.resources import Plate @@ -102,42 +104,82 @@ def __init__( self._objective: Optional[Objective] = None self._acquiring = False - async def setup(self, use_cam: bool = False) -> None: - logger.info(f"{self.__class__.__name__} setting up") + @contextlib.contextmanager + def _spinnaker_system_context(self): + self._spinnaker_system = PySpin.System.GetInstance() + try: + version = self._spinnaker_system.GetLibraryVersion() + logger.debug( + f"{self.__class__.__name__} Library version: %d.%d.%d.%d", + version.major, + version.minor, + version.type, + version.build, + ) - await super().setup() + yield self._spinnaker_system + finally: + # TODO: This looks like a footgun to me. We are releasing the + # system singleton, without knowing if we're the only user + # of that system. 
+ self._spinnaker_system.ReleaseInstance() + self._spinnaker_system = None - if use_cam: + @contextlib.asynccontextmanager + async def _camera_context(self, cam): + for _ in range(10): try: - await self._set_up_camera() - except: - # if setting up the camera fails, we have to close the ftdi connection - # so that the user can try calling setup() again. - # if we don't close the ftdi connection here, it will be open until the - # python kernel is restarted. - try: - await self.stop() - except Exception: - pass - raise - - async def stop(self): - await super().stop() + cam.Init() # SpinnakerException: Spinnaker: Could not read the XML URL [-1010] + break + except: # noqa + await anyio.sleep(0.1) + else: + raise RuntimeError("Failed to initialize camera.") - if self._acquiring: - self.stop_acquisition() + try: + yield cam + finally: + try: + if self._acquiring: + self.stop_acquisition() + self._reset_trigger() + finally: + cam.DeInit() + + async def _enter_lifespan(self, stack: AsyncExitStackWithShielding, *, use_cam: bool = False): + await super()._enter_lifespan(stack) + logger.info(f"{self.__class__.__name__} setting up") + + if use_cam: + if not USE_PYSPIN: + raise RuntimeError( + "PySpin is not installed. Please follow the imaging setup instructions. 
" + f"Import error: {_PYSPIN_IMPORT_ERROR}" + ) + if self.imaging_config is None: + raise RuntimeError("Imaging configuration is not set.") + + spinnaker_sys = stack.enter_context(self._spinnaker_system_context()) + cam = self._identify_camera(spinnaker_sys) + await stack.enter_async_context(self._camera_context(cam)) + self._cam = cam - logger.info(f"{self.__class__.__name__} stopping") - await self.stop_shaking() - await self.io.stop() + nodemap = self._cam.GetNodeMap() + await self._setup_trigger(nodemap) - self._stop_camera() + # -- Load filter information -- + if self._filters is None: + await self._load_filters() - self._objectives = None - self._filters = None - self._slow_mode = None + # -- Load objective information -- + if self._objectives is None: + await self._load_objectives() - self._clear_imaging_state() + @stack.callback + def _cleanup_always(): + self._objectives = None + self._filters = None + self._clear_imaging_state() def _clear_imaging_state(self): self._exposure = None @@ -157,82 +199,48 @@ def supports_heating(self): def supports_cooling(self): return True - async def _set_up_camera(self) -> None: - atexit.register(self._stop_camera) - - if not USE_PYSPIN: - raise RuntimeError( - "PySpin is not installed. Please follow the imaging setup instructions. " - f"Import error: {_PYSPIN_IMPORT_ERROR}" - ) - if self.imaging_config is None: - raise RuntimeError("Imaging configuration is not set.") - - logger.debug(f"{self.__class__.__name__} setting up camera") - - # -- Retrieve singleton reference to system object (Spinnaker) -- - self._spinnaker_system = PySpin.System.GetInstance() - version = self._spinnaker_system.GetLibraryVersion() - logger.debug( - f"{self.__class__.__name__} Library version: %d.%d.%d.%d", - version.major, - version.minor, - version.type, - version.build, - ) - - # -- Get the camera by serial number, or the first. 
-- - cam_list = self._spinnaker_system.GetCameras() - num_cameras = cam_list.GetSize() - logger.debug(f"{self.__class__.__name__} number of cameras detected: %d", num_cameras) - - for cam in cam_list: - info = self._get_device_info(cam) - serial_number = info["DeviceSerialNumber"] - logger.debug(f"{self.__class__.__name__} camera detected: %s", serial_number) + def _identify_camera(self, spinnaker_sys) -> "PySpin.Camera": - if ( - self.imaging_config.camera_serial_number is not None - and serial_number == self.imaging_config.camera_serial_number - ): - self._cam = cam - logger.info(f"{self.__class__.__name__} using camera with serial number %s", serial_number) - break - else: # if no specific camera was found by serial number so use the first one - if num_cameras > 0: - self._cam = cam_list.GetByIndex(0) - logger.info( - f"{self.__class__.__name__} using first camera with serial number %s", - info["DeviceSerialNumber"], - ) - else: - logger.error(f"{self.__class__.__name__}: No cameras found") - self._cam = None - cam_list.Clear() + cam_list = spinnaker_sys.GetCameras() + try: + num_cameras = cam_list.GetSize() + logger.debug(f"{self.__class__.__name__} number of cameras detected: %d", num_cameras) + + target_cam = None + for cam in cam_list: + info = self._get_device_info(cam) + serial_number = info["DeviceSerialNumber"] + logger.debug(f"{self.__class__.__name__} camera detected: %s", serial_number) + + if ( + self.imaging_config.camera_serial_number is not None + and serial_number == self.imaging_config.camera_serial_number + ): + target_cam = cam + logger.info( + f"{self.__class__.__name__} using camera with serial number %s", serial_number + ) + break + else: # if no specific camera was found by serial number so use the first one + if num_cameras > 0: + target_cam = cam_list.GetByIndex(0) + info = self._get_device_info(target_cam) + logger.info( + f"{self.__class__.__name__} using first camera with serial number %s", + info["DeviceSerialNumber"], + ) + else: + 
logger.error(f"{self.__class__.__name__}: No cameras found") + target_cam = None + finally: + cam_list.Clear() - if self._cam is None: - raise RuntimeError( - f"{self.__class__.__name__}: No camera found. Make sure the camera is connected and the serial " - "number is correct." - ) + if target_cam is None: + raise RuntimeError(f"{self.__class__.__name__}: No camera found.") - # -- Initialize camera -- - for _ in range(10): - try: - self._cam.Init() # SpinnakerException: Spinnaker: Could not read the XML URL [-1010] - break - except: # noqa - await asyncio.sleep(0.1) - pass - else: - raise RuntimeError( - "Failed to initialize camera. Make sure the camera is connected and the " - "Spinnaker SDK is installed correctly." - ) - nodemap = self._cam.GetNodeMap() + return target_cam - # -- Configure trigger to be software -- - # This is needed for longer exposure times (otherwise 27.8ms is the maximum) + async def _setup_trigger(self, nodemap): # 1. Set trigger selector to frame start ptr_trigger_selector = PySpin.CEnumerationPtr(nodemap.GetNode("TriggerSelector")) if not PySpin.IsReadable(ptr_trigger_selector) or not PySpin.IsWritable(ptr_trigger_selector): @@ -263,15 +271,7 @@ async def _set_up_camera(self) -> None: ptr_trigger_mode.SetIntValue(int(ptr_trigger_on.GetNumericValue())) # "NOTE: Blackfly and Flea3 GEV cameras need 1 second delay after trigger mode is turned on" - await asyncio.sleep(1) - - # -- Load filter information -- - if self._filters is None: - await self._load_filters() - - # -- Load objective information -- - if self._objectives is None: - await self._load_objectives() + await anyio.sleep(1) @property def objectives(self) -> List[Optional[Objective]]: @@ -422,18 +422,6 @@ async def _load_objectives(self): else: raise RuntimeError(f"{self.__class__.__name__}: Unsupported version: {self.version}") - def _stop_camera(self) -> None: - if self._cam is not None: - if self._acquiring: - self.stop_acquisition() - - self._reset_trigger() - - 
self._cam.DeInit() - self._cam = None - if self._spinnaker_system is not None: - self._spinnaker_system.ReleaseInstance() - def _reset_trigger(self): if self._cam is None: return @@ -601,7 +589,7 @@ async def set_position(self, x: float, y: float): await self.send_command("Y", f"O01{relative_y_str}") self._pos_x, self._pos_y = x, y - await asyncio.sleep(0.1) + await anyio.sleep(0.1) async def set_auto_exposure(self, auto_exposure: Literal["off", "once", "continuous"]): if self._cam is None: @@ -799,7 +787,7 @@ async def _acquire_image( try: node_softwaretrigger_cmd.Execute() timeout = int(self._cam.ExposureTime.GetValue() / 1000 + 1000) # from example - image_result = self._cam.GetNextImage(timeout) + image_result = await anyio.to_thread.run_sync(self._cam.GetNextImage, timeout) if not image_result.IsIncomplete(): processor = PySpin.ImageProcessor() processor.SetColorProcessing(color_processing_algorithm) @@ -817,7 +805,7 @@ async def _acquire_image( ) num_tries += 1 - await asyncio.sleep(0.3) + await anyio.sleep(0.3) raise TimeoutError("max_image_read_attempts reached") async def capture( diff --git a/pylabrobot/plate_reading/agilent/biotek_synergyh1_backend.py b/pylabrobot/plate_reading/agilent/biotek_synergyh1_backend.py index 5036bb33b83..8bc98226582 100644 --- a/pylabrobot/plate_reading/agilent/biotek_synergyh1_backend.py +++ b/pylabrobot/plate_reading/agilent/biotek_synergyh1_backend.py @@ -1,8 +1,8 @@ -import asyncio import logging -import time from typing import Optional +import anyio + try: from pylibftdi import FtdiError @@ -37,52 +37,53 @@ async def _read_until( if timeout is None: timeout = self.timeout - deadline = time.time() + timeout buf = bytearray() retries = 0 max_retries = 3 - while True: - if time.time() > deadline: - logger.debug( - f"{self.__class__.__name__} _read_until timed out; partial buffer (hex): %s", buf.hex() - ) - raise TimeoutError( - f"{self.__class__.__name__} _read_until timed out waiting for {terminator!r}; partial={buf.hex()}" 
- ) - - try: - data = await self.io.read(chunk_size) - if len(data) == 0: - await asyncio.sleep(0.02) - continue - - buf.extend(data) - - if terminator in buf: - idx = buf.index(terminator) + len(terminator) - full = bytes(buf[:idx]) - logger.debug( - f"{self.__class__.__name__} _read_until received %d bytes (hex prefix): %s", - len(full), - full[:200].hex(), - ) - return full - - except FtdiError as e: - retries += 1 - logger.warning( - f"{self.__class__.__name__} transient FtdiError while reading: %s — retrying", e - ) - - if retries >= max_retries: - logger.warning( - f"{self.__class__.__name__} too many FtdiError retries ({max_retries}) — stopping", e - ) - raise - - await asyncio.sleep(0.05) - continue - except Exception: - raise + try: + with anyio.fail_after(timeout): + while True: + try: + data = await self.io.read(chunk_size) + if len(data) == 0: + await anyio.sleep(0.02) + continue + + buf.extend(data) + + if terminator in buf: + idx = buf.index(terminator) + len(terminator) + full = bytes(buf[:idx]) + logger.debug( + f"{self.__class__.__name__} _read_until received %d bytes (hex prefix): %s", + len(full), + full[:200].hex(), + ) + return full + + except FtdiError as e: + retries += 1 + logger.warning( + f"{self.__class__.__name__} transient FtdiError while reading: %s — retrying", e + ) + + if retries >= max_retries: + logger.warning( + f"{self.__class__.__name__} too many FtdiError retries ({max_retries}) — stopping: %s", + e, + ) + raise + + await anyio.sleep(0.05) + continue + except Exception: + raise + except TimeoutError: + logger.debug( + f"{self.__class__.__name__} _read_until timed out; partial buffer (hex): %s", buf.hex() + ) + raise TimeoutError( + f"{self.__class__.__name__} _read_until timed out waiting for {terminator!r}; partial={buf.hex()}" + ) diff --git a/pylabrobot/plate_reading/agilent/biotek_tests.py b/pylabrobot/plate_reading/agilent/biotek_tests.py index d011901249f..ccd8e2024c6 100644 --- 
a/pylabrobot/plate_reading/agilent/biotek_tests.py +++ b/pylabrobot/plate_reading/agilent/biotek_tests.py @@ -1,13 +1,15 @@ # mypy: disable-error-code = attr-defined +import contextlib import math -import unittest import unittest.mock from typing import Iterator import pytest +from pylabrobot.testing.concurrency import AnyioTestBase + pytest.importorskip("pylibftdi") from pylabrobot.plate_reading.agilent.biotek_cytation_backend import CytationBackend @@ -19,14 +21,14 @@ def _byte_iter(s: str) -> Iterator[bytes]: yield c.encode() -class TestCytation5Backend(unittest.IsolatedAsyncioTestCase): +class TestCytation5Backend(AnyioTestBase): """Tests for the Cytation5Backend.""" - async def asyncSetUp(self): + async def _enter_lifespan(self, stack: contextlib.AsyncExitStack): + await super()._enter_lifespan(stack) self.backend = CytationBackend(timeout=0.1) self.backend.io = unittest.mock.MagicMock() - self.backend.io.setup = unittest.mock.AsyncMock() - self.backend.io.stop = unittest.mock.AsyncMock() + self.backend.io._enter_lifespan = unittest.mock.AsyncMock() self.backend.io.read = unittest.mock.AsyncMock() self.backend.io.write = unittest.mock.AsyncMock() self.backend.io.usb_reset = unittest.mock.AsyncMock() @@ -40,22 +42,18 @@ async def asyncSetUp(self): self.plate = CellVis_24_wellplate_3600uL_Fb(name="plate") # Mock time.time() to control the timestamp in the results - self.mock_time = unittest.mock.patch("time.time", return_value=12345.6789).start() - self.addCleanup(self.mock_time.stop) + stack.enter_context(unittest.mock.patch("time.time", return_value=12345.6789)) async def test_setup(self): self.backend.io.read.side_effect = _byte_iter("\x061650200 Version 1.04 0000\x03") - await self.backend.setup() - assert self.backend.io.setup.called - self.backend.io.usb_reset.assert_called_once() - self.backend.io.set_latency_timer.assert_called_with(16) - self.backend.io.set_baudrate.assert_called_with(9600) - # self.backend.io.set_line_property.assert_called_with(8, 2, 
0) #? - self.backend.io.set_flowctrl.assert_called_with(0x100) - self.backend.io.set_rts.assert_called_with(1) - - await self.backend.stop() - assert self.backend.io.stop.called + async with self.backend: + assert self.backend.io._enter_lifespan.called + self.backend.io.usb_reset.assert_called_once() + self.backend.io.set_latency_timer.assert_called_with(16) + self.backend.io.set_baudrate.assert_called_with(9600) + # self.backend.io.set_line_property.assert_called_with(8, 2, 0) #? + self.backend.io.set_flowctrl.assert_called_with(0x100) + self.backend.io.set_rts.assert_called_with(1) async def test_get_serial_number(self): self.backend.io.read.side_effect = _byte_iter("\x0600000000 0000\x03") diff --git a/pylabrobot/plate_reading/backend.py b/pylabrobot/plate_reading/backend.py index f793e18a023..a0d1b91d6b0 100644 --- a/pylabrobot/plate_reading/backend.py +++ b/pylabrobot/plate_reading/backend.py @@ -20,14 +20,6 @@ class PlateReaderBackend(MachineBackend, metaclass=ABCMeta): """An abstract class for a plate reader. Plate readers are devices that can read luminescence, absorbance, or fluorescence from a plate.""" - @abstractmethod - async def setup(self) -> None: - """Set up the plate reader. This should be called before any other methods.""" - - @abstractmethod - async def stop(self) -> None: - """Close all connections to the plate reader and make sure setup() can be called again.""" - @abstractmethod async def open(self) -> None: """Open the plate reader. 
Also known as plate out.""" diff --git a/pylabrobot/plate_reading/bmg_labtech/clario_star_backend.py b/pylabrobot/plate_reading/bmg_labtech/clario_star_backend.py index f3459aa138e..c4fbbc1bf6b 100644 --- a/pylabrobot/plate_reading/bmg_labtech/clario_star_backend.py +++ b/pylabrobot/plate_reading/bmg_labtech/clario_star_backend.py @@ -1,4 +1,3 @@ -import asyncio import logging import math import struct @@ -6,7 +5,10 @@ import time from typing import Dict, List, Optional, Tuple, Union +import anyio + from pylabrobot import utils +from pylabrobot.concurrency import AsyncExitStackWithShielding from pylabrobot.io.ftdi import FTDI from pylabrobot.resources.plate import Plate from pylabrobot.resources.well import Well @@ -30,8 +32,9 @@ def __init__(self, device_id: Optional[str] = None): human_readable_device_name="BMG CLARIOstar", device_id=device_id, vid=0x0403, pid=0xBB68 ) - async def setup(self): - await self.io.setup() + async def _enter_lifespan(self, stack: AsyncExitStackWithShielding): + await super()._enter_lifespan(stack) + await stack.enter_async_context(self.io) await self.io.set_baudrate(125000) await self.io.set_line_property(8, 0, 0) # 8N1 await self.io.set_latency_timer(2) @@ -39,9 +42,6 @@ async def setup(self): await self.initialize() await self.request_eeprom_data() - async def stop(self): - await self.io.stop() - async def get_stat(self): stat = await self.io.poll_modem_status() return hex(stat) @@ -51,35 +51,32 @@ async def read_resp(self, timeout=20) -> bytes: been read so far.""" d = b"" - last_read = b"" end_byte_found = False - t = time.time() # Commands are terminated with 0x0d, but this value may also occur as a part of the response. # Therefore, we read until we read a 0x0d, but if that's the last byte we read in a full packet, # we keep reading for at least one more cycle. We only check the timeout if the last read was # unsuccessful (i.e. keep reading if we are still getting data). 
- while True: - last_read = await self.io.read(25) # 25 is max length observed in pcap - if len(last_read) > 0: - d += last_read - end_byte_found = d[-1] == 0x0D - if ( - len(last_read) < 25 and end_byte_found - ): # if we read less than 25 bytes, we're at the end - break - else: - # If we didn't read any data, check if the last read ended in an end byte. If so, we're done - if end_byte_found: - break - - # Check if we've timed out. - if time.time() - t > timeout: - logger.warning("timed out reading response") - break - - # If we read data, we don't wait and immediately try to read more. - await asyncio.sleep(0.0001) + with anyio.move_on_after(timeout) as scope: + while True: + last_read = await self.io.read(25) # 25 is max length observed in pcap + if len(last_read) > 0: + d += last_read + end_byte_found = d[-1] == 0x0D + if ( + len(last_read) < 25 and end_byte_found + ): # if we read less than 25 bytes, we're at the end + break + else: + # If we didn't read any data, check if the last read ended in an end byte. If so, we're done + if end_byte_found: + break + + # If we read data, we don't wait and immediately try to read more. + await anyio.sleep(0.0001) + + if scope.cancel_called: + logger.warning("timed out reading response") logger.debug("read %s", d.hex()) @@ -105,40 +102,40 @@ async def send(self, cmd: Union[bytearray, bytes], read_timeout=20): async def _wait_for_ready_and_return(self, ret, timeout=150): """Wait for the plate reader to be ready and return the response.""" last_status = None - t = time.time() - while time.time() - t < timeout: - await asyncio.sleep(0.1) - - command_status = await self.read_command_status() - - if len(command_status) != 24: - logger.warning( - "unexpected response %s. 
I think a command status response is always 24 bytes", - command_status, - ) - continue - - if command_status != last_status: - logger.info("status changed %s", command_status.hex()) - last_status = command_status - else: - continue - - if command_status[2] != 0x18 or command_status[3] != 0x0C or command_status[4] != 0x01: - logger.warning( - "unexpected response %s. I think 18 0c 01 indicates a command status response", - command_status, - ) - - if command_status[5] not in { - 0x25, - 0x05, - }: # 25 is busy, 05 is ready. probably. - logger.warning("unexpected response %s.", command_status) - - if command_status[5] == 0x05: - logger.debug("status is ready") - return ret + with anyio.fail_after(timeout): + while True: + await anyio.sleep(0.1) + + command_status = await self.read_command_status() + + if len(command_status) != 24: + logger.warning( + "unexpected response %s. I think a command status response is always 24 bytes", + command_status, + ) + continue + + if command_status != last_status: + logger.info("status changed %s", command_status.hex()) + last_status = command_status + else: + continue + + if command_status[2] != 0x18 or command_status[3] != 0x0C or command_status[4] != 0x01: + logger.warning( + "unexpected response %s. I think 18 0c 01 indicates a command status response", + command_status, + ) + + if command_status[5] not in { + 0x25, + 0x05, + }: # 25 is busy, 05 is ready. probably. + logger.warning("unexpected response %s.", command_status) + + if command_status[5] == 0x05: + logger.debug("status is ready") + return ret async def read_command_status(self): status = await self.send(b"\x02\x00\x09\x0c\x80\x00") @@ -227,7 +224,7 @@ async def _run_luminescence(self, focal_height: float, plate: Plate): # TODO: find a prettier way to do this. It's essentially copied from _wait_for_ready_and_return. 
last_status = None while True: - await asyncio.sleep(0.1) + await anyio.sleep(0.1) command_status = await self.read_command_status() @@ -259,7 +256,7 @@ async def _run_absorbance(self, wavelength: float, plate: Plate): # TODO: find a prettier way to do this. It's essentially copied from _wait_for_ready_and_return. last_status = None while True: - await asyncio.sleep(0.1) + await anyio.sleep(0.1) command_status = await self.read_command_status() diff --git a/pylabrobot/plate_reading/byonoy/byonoy_backend.py b/pylabrobot/plate_reading/byonoy/byonoy_backend.py index ca2ae684cf3..7f0e82c380e 100644 --- a/pylabrobot/plate_reading/byonoy/byonoy_backend.py +++ b/pylabrobot/plate_reading/byonoy/byonoy_backend.py @@ -1,10 +1,11 @@ import abc -import asyncio import enum -import threading import time from typing import Dict, List, Optional +import anyio + +from pylabrobot.concurrency import AsyncExitStackWithShielding from pylabrobot.io.binary import Reader, Writer from pylabrobot.io.hid import HID from pylabrobot.plate_reading.backend import PlateReaderBackend @@ -24,31 +25,22 @@ class _ByonoyBase(PlateReaderBackend, metaclass=abc.ABCMeta): def __init__(self, pid: int, device_type: _ByonoyDevice) -> None: self.io = HID(human_readable_device_name="Byonoy Plate Reader", vid=0x16D0, pid=pid) - self._background_thread: Optional[threading.Thread] = None - self._stop_background = threading.Event() self._ping_interval = 1.0 # Send ping every second self._sending_pings = False # Whether to actively send pings self._device_type = device_type - async def setup(self) -> None: + async def _enter_lifespan(self, stack: AsyncExitStackWithShielding): """Set up the plate reader. 
This should be called before any other methods.""" + await super()._enter_lifespan(stack) - await self.io.setup() + await stack.enter_async_context(self.io) # Start background keep alive messages - self._stop_background.clear() - self._background_thread = threading.Thread(target=self._background_ping_worker, daemon=True) - self._background_thread.start() - - async def stop(self) -> None: - """Close all connections to the plate reader and make sure setup() can be called again.""" - # Stop background keep alive messages - self._stop_background.set() - if self._background_thread and self._background_thread.is_alive(): - self._background_thread.join(timeout=2.0) + tg = await stack.enter_async_context(anyio.create_task_group()) + stack.callback(tg.cancel_scope.cancel) - await self.io.stop() + tg.start_soon(self._ping_loop) def _assemble_command(self, report_id: int, payload: bytes, routing_info: bytes) -> bytes: packet = Writer().u16(report_id).raw_bytes(payload).finish() @@ -68,34 +60,24 @@ async def send_command( if not wait_for_response: return None - t0 = time.time() - while True: - if time.time() - t0 > 120: # read for 2 minutes max. typical is 1m5s. 
- raise TimeoutError("Reading luminescence data timed out after 2 minutes.") - - response = await self.io.read(64, timeout=30) - if len(response) == 0: - continue - - # if the first 2 bytes do not match, we continue reading - response_report_id = Reader(response).u16() - if report_id == response_report_id: - break - return response - - def _background_ping_worker(self) -> None: - """Background worker that sends periodic ping commands.""" - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete(self._ping_loop()) - finally: - loop.close() + try: + with anyio.fail_after(120): + while True: + response = await self.io.read(64, timeout=30) + if len(response) == 0: + continue + + # if the first 2 bytes do not match, we continue reading + response_report_id = Reader(response).u16() + if report_id == response_report_id: + break + except TimeoutError: + raise TimeoutError("Timeout waiting for response from Byonoy device") from None + return response async def _ping_loop(self) -> None: - """Main ping loop that runs in the background thread.""" - while not self._stop_background.is_set(): + """Main ping loop that runs in the background.""" + while True: if self._sending_pings: # don't read in background thread, data might get lost here. don't use send_command payload = Writer().u8(1).finish() @@ -106,7 +88,7 @@ async def _ping_loop(self) -> None: ) await self.io.write(cmd) - self._stop_background.wait(self._ping_interval) + await anyio.sleep(self._ping_interval) def _start_background_pings(self) -> None: self._sending_pings = True @@ -129,11 +111,9 @@ class ByonoyAbsorbance96AutomateBackend(_ByonoyBase): def __init__(self) -> None: super().__init__(pid=0x1199, device_type=_ByonoyDevice.ABSORBANCE_96) - async def setup(self, verbose: bool = False, **backend_kwargs): + async def _enter_lifespan(self, stack: AsyncExitStackWithShielding): """Set up the plate reader. 
This should be called before any other methods.""" - - # Call the base setup (opens HID) - await super().setup(**backend_kwargs) + await super()._enter_lifespan(stack) # After device is online, run reference initialisation await self.initialize_measurements() @@ -197,34 +177,34 @@ async def _run_abs_measurement(self, signal_wl: int, reference_wl: int, is_refer # (4) Collect chunks (report_id 0x0500) rows: List[float] = [] - t0 = time.time() - - while True: - if time.time() - t0 > 120: - raise TimeoutError("Measurement timeout.") - chunk = await self.io.read(64, timeout=30) - if len(chunk) == 0: - continue - - reader = Reader(chunk) - report_id = reader.u16() - - # Only handle the measurement packets - if report_id == 0x0500: - seq = reader.u8() - seq_len = reader.u8() - _ = reader.i16() # signal_wl_nm - _ = reader.i16() # reference_wl_nm - _ = reader.u32() # duration_ms - row = [reader.f32() for _ in range(12)] - _ = reader.u8() # flags - _ = reader.u8() # progress - - rows.extend(row) - - if seq == seq_len - 1: - break + try: + with anyio.fail_after(120): + while True: + chunk = await self.io.read(64, timeout=30) + if len(chunk) == 0: + continue + + reader = Reader(chunk) + report_id = reader.u16() + + # Only handle the measurement packets + if report_id == 0x0500: + seq = reader.u8() + seq_len = reader.u8() + _ = reader.i16() # signal_wl_nm + _ = reader.i16() # reference_wl_nm + _ = reader.u32() # duration_ms + row = [reader.f32() for _ in range(12)] + _ = reader.u8() # flags + _ = reader.u8() # progress + + rows.extend(row) + + if seq == seq_len - 1: + break + except TimeoutError: + raise TimeoutError("Timeout waiting for measurement data from Byonoy device") from None return rows @@ -345,33 +325,33 @@ async def read_luminescence( wait_for_response=False, ) - t0 = time.time() all_rows: List[float] = [] - while True: - if time.time() - t0 > 120: # read for 2 minutes max. typical is 1m5s. 
-        raise TimeoutError("Reading luminescence data timed out after 2 minutes.")
-
-      chunk = await self.io.read(64, timeout=30)
-      if len(chunk) == 0:
-        continue
-
-      reader = Reader(chunk)
-      report_id = reader.u16()
-
-      if report_id == 0x0600:  # REP_LUM96_MEASUREMENT_IN
-        seq = reader.u8()
-        seq_len = reader.u8()
-        _ = reader.u32()  # integration_time_us
-        _ = reader.u32()  # duration_ms
-        row = [reader.f32() for _ in range(12)]
-        _ = reader.u8()  # flags
-        _ = reader.u8()  # progress
-
-        all_rows.extend(row)
-
-        if seq == seq_len - 1:
-          break
+    try:
+      with anyio.fail_after(120):
+        while True:
+          chunk = await self.io.read(64, timeout=30)
+          if len(chunk) == 0:
+            continue
+
+          reader = Reader(chunk)
+          report_id = reader.u16()
+
+          if report_id == 0x0600:  # REP_LUM96_MEASUREMENT_IN
+            seq = reader.u8()
+            seq_len = reader.u8()
+            _ = reader.u32()  # integration_time_us
+            _ = reader.u32()  # duration_ms
+            row = [reader.f32() for _ in range(12)]
+            _ = reader.u8()  # flags
+            _ = reader.u8()  # progress
+
+            all_rows.extend(row)
+
+            if seq == seq_len - 1:
+              break
+    except TimeoutError:
+      raise TimeoutError("Timeout waiting for luminescence data from Byonoy device") from None
 
     hybrid_result = all_rows[96 * 0 : 96 * 1]
     _ = all_rows[96 * 1 : 96 * 2]  # counting_result
diff --git a/pylabrobot/plate_reading/byonoy/byonoy_backend_tests.py b/pylabrobot/plate_reading/byonoy/byonoy_backend_tests.py
new file mode 100644
index 00000000000..74d1408c93a
--- /dev/null
+++ b/pylabrobot/plate_reading/byonoy/byonoy_backend_tests.py
@@ -0,0 +1,41 @@
+import contextlib
+import unittest.mock
+
+import anyio
+import pytest
+
+from pylabrobot.plate_reading.byonoy.byonoy_backend import ByonoyAbsorbance96AutomateBackend
+from pylabrobot.testing.concurrency import AnyioTestBase
+
+
+class TestByonoyBackend(AnyioTestBase):
+  async def _enter_lifespan(self, stack: contextlib.AsyncExitStack):
+    await super()._enter_lifespan(stack)
+    self.backend = ByonoyAbsorbance96AutomateBackend()
+    self.backend.io = unittest.mock.AsyncMock()
+
+    self.backend.get_available_absorbance_wavelengths = unittest.mock.AsyncMock(  # type: ignore[method-assign]
+      return_value=[450, 660]
+    )
+    self.backend.initialize_measurements = unittest.mock.AsyncMock()  # type: ignore[method-assign]
+
+  @pytest.mark.parametrize("backend", ["asyncio", "trio"])
+  async def test_setup(self):
+    async with self.backend:
+      assert self.backend.io.__aenter__.called  # type: ignore[attr-defined]
+      assert self.backend.initialize_measurements.called  # type: ignore[attr-defined]
+
+      assert self.backend.available_wavelengths == [450, 660]
+
+      # Pings are not sent by default.
+      assert not self.backend._sending_pings
+
+      # Enable pings
+      self.backend._start_background_pings()
+      assert self.backend._sending_pings
+
+      # Wait for a bit to let the ping loop run
+      await anyio.sleep(1.5)
+      assert self.backend.io.write.called
diff --git a/pylabrobot/plate_reading/chatterbox.py b/pylabrobot/plate_reading/chatterbox.py
index 4def8f3a806..84fd947de68 100644
--- a/pylabrobot/plate_reading/chatterbox.py
+++ b/pylabrobot/plate_reading/chatterbox.py
@@ -1,6 +1,7 @@
 import time
 from typing import Dict, List, Optional
 
+from pylabrobot.concurrency import AsyncExitStackWithShielding
 from pylabrobot.plate_reading.backend import PlateReaderBackend
 from pylabrobot.resources import Plate, Well
 
@@ -14,11 +15,10 @@ def __init__(self):
     self.dummy_absorbance: List[List[Optional[float]]] = [[0.0] * 12] * 8
     self.dummy_fluorescence: List[List[Optional[float]]] = [[0.0] * 12] * 8
 
-  async def setup(self) -> None:
+  async def _enter_lifespan(self, stack: AsyncExitStackWithShielding):
+    await super()._enter_lifespan(stack)
     print("Setting up the plate reader.")
-
-  async def stop(self) -> None:
-    print("Stopping the plate reader.")
+    stack.callback(lambda: print("Stopping the plate reader."))
 
   async def open(self) -> None:
     print("Opening the plate reader.")
diff --git a/pylabrobot/plate_reading/imager.py b/pylabrobot/plate_reading/imager.py
index c55e0737d7f..8fd6d3166a7 100644
--- a/pylabrobot/plate_reading/imager.py
+++ b/pylabrobot/plate_reading/imager.py
@@ -1,8 +1,9 @@
 import logging
 import math
-import time
 from typing import Any, Awaitable, Callable, Coroutine, Dict, Literal, Optional, Tuple, Union, cast
 
+import anyio
+
 from pylabrobot.machines import Machine, need_setup_finished
 from pylabrobot.plate_reading.backend import ImagerBackend
 from pylabrobot.plate_reading.standard import (
@@ -55,19 +56,20 @@ async def cached_func(x: float) -> float:
       cache[x] = await func(x)
     return cache[x]
 
-  t0 = time.time()
   iteration = 0
-  while abs(b - a) > tol:
-    if (await cached_func(c)) > (await cached_func(d)):
-      b = d
-    else:
-      a = c
-    c = b - (b - a) / phi
-    d = a + (b - a) / phi
-    if time.time() - t0 > timeout:
-      raise TimeoutError("Timeout while searching for optimal focus position")
-    iteration += 1
-    logger.debug("Golden ratio search (autofocus) iteration %d, a=%s, b=%s", iteration, a, b)
+  try:
+    with anyio.fail_after(timeout):
+      while abs(b - a) > tol:
+        if (await cached_func(c)) > (await cached_func(d)):
+          b = d
+        else:
+          a = c
+        c = b - (b - a) / phi
+        d = a + (b - a) / phi
+        iteration += 1
+        logger.debug("Golden ratio search (autofocus) iteration %d, a=%s, b=%s", iteration, a, b)
+  except TimeoutError:
+    raise TimeoutError(f"Autofocus did not converge within {timeout} seconds") from None
   return (b + a) / 2
diff --git a/pylabrobot/plate_reading/molecular_devices/backend.py b/pylabrobot/plate_reading/molecular_devices/backend.py
index fa573388958..d4a0d06f697 100644
--- a/pylabrobot/plate_reading/molecular_devices/backend.py
+++ b/pylabrobot/plate_reading/molecular_devices/backend.py
@@ -1,12 +1,13 @@
-import asyncio
 import logging
 import re
-import time
 from abc import ABCMeta
 from dataclasses import dataclass, field
 from enum import Enum
 from typing import Dict, List, Literal, Optional, Tuple, Union
 
+import anyio
+
+from pylabrobot.concurrency import AsyncExitStackWithShielding
 from pylabrobot.io.serial import Serial
 from pylabrobot.plate_reading.backend import PlateReaderBackend
 from pylabrobot.resources.plate import Plate
@@ -258,13 +259,11 @@ def __init__(self, port: str) -> None:
       timeout=0.2,
     )
 
-  async def setup(self) -> None:
-    await self.io.setup()
+  async def _enter_lifespan(self, stack: AsyncExitStackWithShielding):
+    await super()._enter_lifespan(stack)
+    await stack.enter_async_context(self.io)
     await self.send_command("!")
 
-  async def stop(self) -> None:
-    await self.io.stop()
-
   def serialize(self) -> dict:
     return {**super().serialize(), "port": self.port}
@@ -282,14 +281,16 @@ async def send_command(
     await self.io.write(command.encode() + b"\r")
 
     raw_response = b""
-    timeout_time = time.time() + timeout
-    while True:
-      raw_response += await self.io.readline()
-      await asyncio.sleep(0.001)
-      if time.time() > timeout_time:
-        raise TimeoutError(f"Timeout waiting for response to command: {command}")
-      if raw_response.count(RES_TERM_CHAR) >= num_res_fields:
-        break
+    try:
+      with anyio.fail_after(timeout):
+        while True:
+          raw_response += await self.io.readline()
+          await anyio.sleep(0.001)
+          if raw_response.count(RES_TERM_CHAR) >= num_res_fields:
+            break
+    except TimeoutError:
+      raise TimeoutError(f"Timeout waiting for response to command: {command}") from None
+
     logger.debug("[plate reader] Command: %s, Response: %s", command, raw_response)
     response = raw_response.decode("utf-8", errors="replace").strip().split(RES_TERM_CHAR.decode())
     response = [r.strip() for r in response if r.strip() != ""]
@@ -682,14 +683,15 @@ def _get_cutoff_filter_index_from_wavelength(self, wavelength: int) -> int:
 
   async def _wait_for_idle(self, timeout: int = 600):
     """Wait for the plate reader to become idle."""
-    start_time = time.time()
-    while True:
-      if time.time() - start_time > timeout:
-        raise TimeoutError("Timeout waiting for plate reader to become idle.")
-      status = await self.get_status()
-      if status and status[1] == "IDLE":
-        break
-      await asyncio.sleep(1)
+    try:
+      with anyio.fail_after(timeout):
+        while True:
+          status = await self.get_status()
+          if status and status[1] == "IDLE":
+            break
+          await anyio.sleep(1)
+    except TimeoutError:
+      raise TimeoutError("Timeout waiting for plate reader to become idle.") from None
 
   async def read_absorbance(  # type: ignore[override]
     self,
diff --git a/pylabrobot/plate_reading/molecular_devices/backend_tests.py b/pylabrobot/plate_reading/molecular_devices/backend_tests.py
index 589d65386cb..441b2a5fc1a 100644
--- a/pylabrobot/plate_reading/molecular_devices/backend_tests.py
+++ b/pylabrobot/plate_reading/molecular_devices/backend_tests.py
@@ -1,7 +1,8 @@
 import math
-import unittest
 from unittest.mock import AsyncMock, MagicMock, call, patch
 
+import pytest
+
 from pylabrobot.plate_reading.molecular_devices.backend import (
   Calibrate,
   CarriageSpeed,
@@ -18,38 +19,46 @@
   SpectrumSettings,
 )
 from pylabrobot.resources.agenbio.plates import AGenBio_96_wellplate_Ub_2200ul
+from pylabrobot.testing.concurrency import AnyioTestBase
 
 
-class TestMolecularDevicesBackend(unittest.IsolatedAsyncioTestCase):
+class TestMolecularDevicesBackend(AnyioTestBase):
   backend: MolecularDevicesBackend
   mock_serial: MagicMock
   send_command_mock: AsyncMock
 
-  def setUp(self):
+  async def _enter_lifespan(self, stack):
+    await super()._enter_lifespan(stack)
+
     self.mock_serial = MagicMock()
-    self.mock_serial.setup = AsyncMock()
-    self.mock_serial.stop = AsyncMock()
+    self.mock_serial.__aenter__ = AsyncMock(return_value=self.mock_serial)
+    self.mock_serial.__aexit__ = AsyncMock(return_value=None)
     self.mock_serial.write = AsyncMock()
     self.mock_serial.readline = AsyncMock(return_value=b"OK>\r\n")
-    with patch("pylabrobot.io.serial.Serial", return_value=self.mock_serial):
-      self.backend = MolecularDevicesBackend(port="COM1")
-      self.backend.io = self.mock_serial
-      self.send_command_mock = patch.object(
-        self.backend, "send_command", new_callable=AsyncMock
-      ).start()
-      self.addCleanup(patch.stopall)
+    stack.enter_context(patch("pylabrobot.io.serial.Serial", return_value=self.mock_serial))
+
+    self.backend = MolecularDevicesBackend(port="COM1")
+    self.backend.io = self.mock_serial
+
+    self.send_command_mock = stack.enter_context(
+      patch.object(self.backend, "send_command", new_callable=AsyncMock)
+    )
 
   async def test_setup_stop(self):
+    import sniffio
+
+    if sniffio.current_async_library() == "trio":
+      pytest.skip("global_manager is not supported on trio")
+
     # un-mock send_command for this test
     with patch.object(
       self.backend, "send_command", wraps=self.backend.send_command
     ) as wrapped_send_command:
-      await self.backend.setup()
-      self.mock_serial.setup.assert_called_once()
-      wrapped_send_command.assert_called_with("!")
-      await self.backend.stop()
-      self.mock_serial.stop.assert_called_once()
+      async with self.backend:
+        self.mock_serial.__aenter__.assert_called_once()
+        wrapped_send_command.assert_called_with("!")
+      self.mock_serial.__aexit__.assert_called_once()
 
   async def test_set_clear(self):
     await self.backend._set_clear()
@@ -676,15 +685,16 @@ async def test_read_time_resolved_fluorescence(
     mock_transfer_data.assert_called_once()
 
 
-class TestDataParsing(unittest.IsolatedAsyncioTestCase):
+class TestDataParsing(AnyioTestBase):
   send_command_mock: AsyncMock
 
-  def setUp(self):
-    with patch("pylabrobot.io.serial.Serial", return_value=MagicMock()):
-      self.backend = MolecularDevicesBackend(port="COM1")
-      self.send_command_mock = patch.object(
-        self.backend, "send_command", new_callable=AsyncMock
-      ).start()
+  async def _enter_lifespan(self, stack):
+    await super()._enter_lifespan(stack)
+    stack.enter_context(patch("pylabrobot.io.serial.Serial", return_value=MagicMock()))
+    self.backend = MolecularDevicesBackend(port="COM1")
+    self.send_command_mock = stack.enter_context(
+      patch.object(self.backend, "send_command", new_callable=AsyncMock)
+    )
 
   def test_parse_absorbance_single_wavelength(self):
     data_str = """
@@ -933,17 +943,18 @@ def data_generator():
     self.assertEqual(result[1]["time"], 12355.6)
 
 
-class TestErrorHandling(unittest.IsolatedAsyncioTestCase):
-  def setUp(self):
+class TestErrorHandling(AnyioTestBase):
+  async def _enter_lifespan(self, stack):
+    await super()._enter_lifespan(stack)
     self.mock_serial = MagicMock()
     self.mock_serial.setup = AsyncMock()
     self.mock_serial.stop = AsyncMock()
     self.mock_serial.write = AsyncMock()
     self.mock_serial.readline = AsyncMock()
-    with patch("pylabrobot.io.serial.Serial", return_value=self.mock_serial):
-      self.backend = MolecularDevicesBackend(port="/dev/tty01")
-      self.backend.io = self.mock_serial
+    stack.enter_context(patch("pylabrobot.io.serial.Serial", return_value=self.mock_serial))
+    self.backend = MolecularDevicesBackend(port="/dev/tty01")
+    self.backend.io = self.mock_serial
 
   async def _mock_send_command_response(self, response_str: str):
     self.mock_serial.readline.side_effect = [response_str.encode() + b">\r\n"]
@@ -995,7 +1006,3 @@ async def test_parse_basic_errors_ok_response(self):
       self.assertEqual(response, ["OK"])
     except MolecularDevicesError:
       self.fail("MolecularDevicesError raised for a valid OK response")
-
-
-if __name__ == "__main__":
-  unittest.main()
diff --git a/pylabrobot/plate_reading/tecan/infinite_backend.py b/pylabrobot/plate_reading/tecan/infinite_backend.py
index 992e422c342..b435217f112 100644
--- a/pylabrobot/plate_reading/tecan/infinite_backend.py
+++ b/pylabrobot/plate_reading/tecan/infinite_backend.py
@@ -6,7 +6,6 @@
 
 from __future__ import annotations
 
-import asyncio
 import logging
 import math
 import re
@@ -15,6 +14,9 @@
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Sequence, Tuple
 
+import anyio
+
+from pylabrobot.concurrency import AsyncExitStackWithShielding
 from pylabrobot.io.binary import Reader
 from pylabrobot.io.usb import USB
 from pylabrobot.plate_reading.backend import PlateReaderBackend
@@ -517,8 +519,6 @@ def __init__(
     self.counts_per_mm_x = counts_per_mm_x
     self.counts_per_mm_y = counts_per_mm_y
     self.counts_per_mm_z = counts_per_mm_z
-    self._setup_lock = asyncio.Lock()
-    self._ready = False
     self._read_chunk_size = 512
     self._max_row_wait_s = 300.0
     self._mode_capabilities: Dict[str, Dict[str, str]] = {}
@@ -527,26 +527,20 @@
     self._run_active = False
     self._active_step_loss_commands: List[str] = []
 
-  async def setup(self) -> None:
-    async with self._setup_lock:
-      if self._ready:
-        return
-      await self.io.setup()
-      await self._initialize_device()
-      for mode in self._MODE_CAPABILITY_COMMANDS:
-        if mode not in self._mode_capabilities:
-          await self._query_mode_capabilities(mode)
-      self._ready = True
-
-  async def stop(self) -> None:
-    async with self._setup_lock:
-      if not self._ready:
-        return
+  async def _enter_lifespan(self, stack: AsyncExitStackWithShielding) -> None:
+    await super()._enter_lifespan(stack)
+    await stack.enter_async_context(self.io)
+    await self._initialize_device()
+    for mode in self._MODE_CAPABILITY_COMMANDS:
+      if mode not in self._mode_capabilities:
+        await self._query_mode_capabilities(mode)
+
+    async def cleanup():
       await self._cleanup_protocol()
-      await self.io.stop()
       self._mode_capabilities.clear()
       self._reset_stream_state()
-      self._ready = False
+
+    stack.push_shielded_async_callback(cleanup)
 
   async def open(self) -> None:
     """Open the reader drawer."""
@@ -902,19 +896,24 @@ async def _await_measurements(
     target = decoder.count + row_count
     start_count = decoder.count
     self._drain_pending_bin_events(decoder)
-    start = time.monotonic()
+    start_time = anyio.current_time()
     reads = 0
-    while decoder.count < target and (time.monotonic() - start) < self._max_row_wait_s:
-      chunk = await self._read_packet(self._read_chunk_size)
-      if not chunk:
-        raise RuntimeError(f"{mode} read returned empty chunk; transport may not support reads.")
-      decoder.feed(chunk)
-      reads += 1
-    if decoder.count < target:
+    try:
+      with anyio.fail_after(self._max_row_wait_s):
+        while decoder.count < target:
+          chunk = await self._read_packet(self._read_chunk_size)
+          if not chunk:
+            raise RuntimeError(
+              f"{mode} read returned empty chunk; transport may not support reads."
+            )
+          decoder.feed(chunk)
+          reads += 1
+    except TimeoutError:
      got = decoder.count - start_count
+      elapsed = anyio.current_time() - start_time
       raise RuntimeError(
         f"Timed out while parsing {mode.lower()} results "
-        f"(decoded {got}/{row_count} measurements in {time.monotonic() - start:.1f}s, {reads} reads)."
+        f"(decoded {got}/{row_count} measurements in {elapsed:.1f}s, {reads} reads)."
       )
 
   def _drain_pending_bin_events(self, decoder: "_MeasurementDecoder") -> None:
@@ -1039,9 +1038,7 @@ async def _read_packet(self, size: int) -> bytes:
 
   async def _recover_transport(self) -> None:
     try:
-      await self.io.stop()
-      await asyncio.sleep(0.2)
-      await self.io.setup()
+      await self.io.recover_transport()
     except Exception:
       logger.warning("Transport recovery failed.", exc_info=True)
       return
diff --git a/pylabrobot/plate_reading/tecan/infinite_backend_tests.py b/pylabrobot/plate_reading/tecan/infinite_backend_tests.py
index 59e3f797bec..a16f221e732 100644
--- a/pylabrobot/plate_reading/tecan/infinite_backend_tests.py
+++ b/pylabrobot/plate_reading/tecan/infinite_backend_tests.py
@@ -670,8 +670,6 @@ def _frame(self, command: str) -> bytes:
     return ExperimentalTecanInfinite200ProBackend._frame_command(command)
 
   async def test_open(self):
-    self.backend._ready = True
-
     await self.backend.open()
 
     self.mock_usb.write.assert_has_calls(
@@ -682,8 +680,6 @@
     )
 
   async def test_close(self):
-    self.backend._ready = True
-
     await self.backend.close(self.plate)
 
     self.mock_usb.write.assert_has_calls(
@@ -695,7 +691,6 @@
 
   async def test_read_absorbance_commands(self):
     """Test that read_absorbance sends the correct configuration commands."""
-    self.backend._ready = True
 
     async def mock_await(decoder, row_count, mode):
       cal_len, cal_blob = _abs_calibration_blob(6000, 0, 1000, 0, 1000)
@@ -750,7 +745,6 @@ async def mock_await(decoder, row_count, mode):
     )
 
   async def test_read_absorbance_uses_late_pending_calibration(self):
-    self.backend._ready = True
     terminal_calls = 0
 
     async def mock_await(decoder, row_count, mode):
@@ -772,7 +766,6 @@ async def mock_terminal(_saw_terminal):
     self.assertAlmostEqual(result[0]["data"][0][0], 0.3010299956639812)
 
   async def test_read_absorbance_subset_prepositions_to_masked_row_start(self):
-    self.backend._ready = True
     wells = self.plate.get_wells(["A2", "A3", "B1", "B2"])
 
     async def mock_await(decoder, row_count, mode):
@@ -808,7 +801,6 @@ async def mock_await(decoder, row_count, mode):
 
   async def test_read_fluorescence_commands(self):
     """Test that read_fluorescence sends the correct configuration commands."""
-    self.backend._ready = True
 
     async def mock_await(decoder, row_count, mode):
       cal_len, cal_blob = _flr_calibration_blob(4850, 0, 0, 1000)
@@ -883,7 +875,6 @@ async def mock_await(decoder, row_count, mode):
 
   async def test_read_luminescence_commands(self):
     """Test that read_luminescence sends the correct configuration commands."""
-    self.backend._ready = True
 
     async def mock_await(decoder, row_count, mode):
       cal_blob = bytes(14)
@@ -943,7 +934,6 @@ async def mock_await(decoder, row_count, mode):
 
   async def test_read_luminescence_defaults_focal_height_to_20mm(self):
     """Test that read_luminescence defaults focal height to 20 mm."""
-    self.backend._ready = True
 
     async def mock_await(decoder, row_count, mode):
       cal_blob = bytes(14)
diff --git a/pylabrobot/plate_reading/tecan/spark20m/spark_backend.py b/pylabrobot/plate_reading/tecan/spark20m/spark_backend.py
index e16cfc5f4c7..5cf09a848da 100644
--- a/pylabrobot/plate_reading/tecan/spark20m/spark_backend.py
+++ b/pylabrobot/plate_reading/tecan/spark20m/spark_backend.py
@@ -3,6 +3,9 @@
 import time
 from typing import Dict, List, Optional
 
+import anyio
+
+from pylabrobot.concurrency import AsyncExitStackWithShielding
 from pylabrobot.plate_reading.backend import PlateReaderBackend
 from pylabrobot.plate_reading.utils import _get_min_max_row_col_tuples
 from pylabrobot.resources.plate import Plate
@@ -48,9 +51,11 @@ def __init__(self, vid: int = 0x0C47) -> None:
     self.sensor_control = SensorControl(self.reader.send_command)
     self.data_control = DataControl(self.reader.send_command)
 
-  async def setup(self) -> None:
+  async def _enter_lifespan(self, stack: AsyncExitStackWithShielding) -> None:
     """Set up the plate reader."""
-    await self.reader.connect()
+    await super()._enter_lifespan(stack)
+    await stack.enter_async_context(self.reader)
+
     await self.config_control.init_module()
     await self.data_control.turn_all_interval_messages_off()
@@ -72,10 +77,6 @@ async def get_average_temperature(self) -> Optional[float]:
 
     return statistics.mean(temps) / 100.0
 
-  async def stop(self) -> None:
-    """Close connections."""
-    await self.reader.close()
-
   async def open(self) -> None:
     """Move the plate carrier out."""
     await self.plate_control.move_to_position(PlatePosition.OUT_RIGHT)
@@ -143,25 +144,23 @@ async def read_absorbance(
       FilterType.BANDPASS, wavelength=wavelength * 10, bandwidth=bandwidth, label=1
     )
 
-    # Start Background Read
-    bg_task, stop_event, results = await self.reader.start_background_read(SparkDevice.ABSORPTION)
-
-    if bg_task is None or stop_event is None or results is None:
-      raise RuntimeError(f"Failed to start background read for {SparkDevice.ABSORPTION.name}")
-
+    # Background Read
+    results = None
     try:
-      # Execute Measurement Sequence
-      await self.measurement_control.prepare_instrument(measure_reference=True)
+      async with self.reader.background_read(SparkDevice.ABSORPTION) as results:
+        if results is None:
+          raise RuntimeError(f"Failed to start background read for {SparkDevice.ABSORPTION.name}")
 
-      await self.scan_plate_range(plate, wells)
-      measurement_time = time.time()
+        # Execute Measurement Sequence
+        await self.measurement_control.prepare_instrument(measure_reference=True)
+        await self.scan_plate_range(plate, wells)
+        measurement_time = time.time()
     finally:
-      stop_event.set()
-      await bg_task
-
-      await self.data_control.turn_all_interval_messages_off()
-      await self.measurement_control.end_measurement()
+      if results is not None:
+        with anyio.CancelScope(shield=True):
+          await self.data_control.turn_all_interval_messages_off()
+          await self.measurement_control.end_measurement()
 
     # Process results
     data_matrix = process_absorbance(results)
@@ -232,24 +231,23 @@ async def read_fluorescence(
     await self.optics_control.set_signal_gain(gain)
     await self.measurement_control.set_number_of_reads(num_reads)
 
-    # Start Background Read
-    bg_task, stop_event, results = await self.reader.start_background_read(SparkDevice.FLUORESCENCE)
-
-    if bg_task is None or stop_event is None or results is None:
-      raise RuntimeError(f"Failed to start background read for {SparkDevice.FLUORESCENCE.name} ")
-
+    # Background Read
+    results = None
     try:
-      # Execute Measurement Sequence
-      await self.measurement_control.prepare_instrument(measure_reference=True)
-      await self.scan_plate_range(plate, wells, focal_height)
-      measurement_time = time.time()
+      async with self.reader.background_read(SparkDevice.FLUORESCENCE) as results:
+        if results is None:
+          raise RuntimeError(f"Failed to start background read for {SparkDevice.FLUORESCENCE.name}")
 
-    finally:
-      stop_event.set()
-      await bg_task
+        # Execute Measurement Sequence
+        await self.measurement_control.prepare_instrument(measure_reference=True)
 
-      await self.data_control.turn_all_interval_messages_off()
-      await self.measurement_control.end_measurement()
+        await self.scan_plate_range(plate, wells, focal_height)
+        measurement_time = time.time()
+    finally:
+      if results is not None:
+        with anyio.CancelScope(shield=True):
+          await self.data_control.turn_all_interval_messages_off()
+          await self.measurement_control.end_measurement()
 
     # Process results
     data_matrix = process_fluorescence(results)
diff --git a/pylabrobot/plate_reading/tecan/spark20m/spark_backend_tests.py b/pylabrobot/plate_reading/tecan/spark20m/spark_backend_tests.py
index 441935c4097..3460003fb9b 100644
--- a/pylabrobot/plate_reading/tecan/spark20m/spark_backend_tests.py
+++ b/pylabrobot/plate_reading/tecan/spark20m/spark_backend_tests.py
@@ -1,19 +1,19 @@
-import asyncio
+import contextlib
 import sys
-import unittest
 from unittest.mock import AsyncMock, MagicMock, patch
 
 from pylabrobot.plate_reading.tecan.spark20m.enums import SparkDevice
 from pylabrobot.plate_reading.tecan.spark20m.spark_backend import ExperimentalSparkBackend
 from pylabrobot.resources.plate import Plate
 from pylabrobot.resources.well import Well
+from pylabrobot.testing.concurrency import AnyioTestBase
 
 sys.modules["usb.core"] = MagicMock()
 sys.modules["usb.util"] = MagicMock()
 
 
-class TestExperimentalSparkBackend(unittest.IsolatedAsyncioTestCase):
-  async def asyncSetUp(self) -> None:
+class TestExperimentalSparkBackend(AnyioTestBase):
+  async def _enter_lifespan(self, stack) -> None:
     # Patch SparkReaderAsync
     self.reader_patcher = patch(
       "pylabrobot.plate_reading.tecan.spark20m.spark_backend.SparkReaderAsync"
@@ -44,27 +44,30 @@ async def asyncSetUp(self) -> None:
       SparkDevice.PLATE_TRANSPORT: MagicMock(),
     }
 
-  async def asyncTearDown(self) -> None:
-    self.reader_patcher.stop()
-    self.abs_proc_patcher.stop()
-    self.fluo_proc_patcher.stop()
+    # Register cleanups
+    @stack.callback
+    def cleanup():
+      self.reader_patcher.stop()
+      self.abs_proc_patcher.stop()
+      self.fluo_proc_patcher.stop()
 
   async def test_setup(self) -> None:
-    await self.backend.setup()
-    self.mock_reader.connect.assert_called_once()
-    # Verify that send_command was called for init_module
-    self.mock_reader.send_command.assert_called()
+    async with self.backend:
+      # Verify that send_command was called for init_module
+      self.mock_reader.send_command.assert_called()
 
   async def test_open(self) -> None:
-    await self.backend.open()
-    self.mock_reader.send_command.assert_called()
+    async with self.backend:
+      await self.backend.open()
+      self.mock_reader.send_command.assert_called()
 
   async def test_read_absorbance(self) -> None:
     # Mock background read
-    stop_event = MagicMock()
-    bg_task: "asyncio.Future[None]" = asyncio.Future()
-    bg_task.set_result(None)
-    self.mock_reader.start_background_read = AsyncMock(return_value=(bg_task, stop_event, []))
+    @contextlib.asynccontextmanager
+    async def mock_bg_read(device_type):
+      yield []
+
+    self.mock_reader.background_read = mock_bg_read
 
     self.mock_process_absorbance.return_value = [[0.5]]
@@ -93,10 +96,11 @@ async def test_read_absorbance(self) -> None:
 
   async def test_read_fluorescence(self) -> None:
     # Mock background read
-    stop_event = MagicMock()
-    bg_task: "asyncio.Future[None]" = asyncio.Future()
-    bg_task.set_result(None)
-    self.mock_reader.start_background_read = AsyncMock(return_value=(bg_task, stop_event, []))
+    @contextlib.asynccontextmanager
+    async def mock_bg_read(device_type):
+      yield []
+
+    self.mock_reader.background_read = mock_bg_read
 
     self.mock_process_fluorescence.return_value = [[100.0]]
@@ -141,7 +145,3 @@ async def test_get_average_temperature_empty(self) -> None:
     self.mock_reader.msgs = []
     temp = await self.backend.get_average_temperature()
     self.assertIsNone(temp)
-
-
-if __name__ == "__main__":
-  unittest.main()
diff --git a/pylabrobot/plate_reading/tecan/spark20m/spark_processor_tests.py b/pylabrobot/plate_reading/tecan/spark20m/spark_processor_tests.py
index 8eda39ac690..b59fa530a20 100644
--- a/pylabrobot/plate_reading/tecan/spark20m/spark_processor_tests.py
+++ b/pylabrobot/plate_reading/tecan/spark20m/spark_processor_tests.py
@@ -362,7 +362,3 @@ def test_process_real_data(self) -> None:
     assert len(proc) == len(res)
     for proc_row, res_row in zip(proc, res):
       assert proc_row == pytest.approx(res_row)
-
-
-if __name__ == "__main__":
-  unittest.main()
diff --git a/pylabrobot/plate_reading/tecan/spark20m/spark_reader_async.py b/pylabrobot/plate_reading/tecan/spark20m/spark_reader_async.py
index 83a1d3256b0..14bcdcc06d2 100644
--- a/pylabrobot/plate_reading/tecan/spark20m/spark_reader_async.py
+++ b/pylabrobot/plate_reading/tecan/spark20m/spark_reader_async.py
@@ -1,7 +1,9 @@
-import asyncio
+import contextlib
+import functools
 import logging
-import time
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, AsyncIterator, Dict, List, Optional
+
+import anyio
 
 try:
   import usb.core
@@ -9,6 +11,7 @@
 except ImportError:
   pass
 
+from pylabrobot.concurrency import AsyncExitStackWithShielding, AsyncResource
 from pylabrobot.io.usb import USB
 
 from .enums import DEVICE_ENDPOINTS, VENDOR_ID, SparkDevice, SparkEndpoint
@@ -19,17 +22,17 @@ class SparkError(Exception):
   """Error returned by the Spark device in a RespError packet."""
 
 
-class SparkReaderAsync:
+class SparkReaderAsync(AsyncResource):
   def __init__(self, vid: int = VENDOR_ID) -> None:
     self.vid: int = vid
     self.devices: Dict[SparkDevice, USB] = {}
     # Per-device discovered endpoints, overriding DEVICE_ENDPOINTS from enums.py
     self.device_endpoints: Dict[SparkDevice, Dict[str, SparkEndpoint]] = {}
     self.seq_num: int = 0
-    self.lock: asyncio.Lock = asyncio.Lock()
+    self.lock: anyio.Lock = anyio.Lock()
     self.msgs: List[Any] = []
 
-  async def connect(self) -> None:
+  async def _enter_lifespan(self, stack: AsyncExitStackWithShielding) -> None:
     logging.info(f"Scanning for devices with VID={hex(self.vid)}...")
 
     for device_type in SparkDevice:
@@ -76,8 +79,10 @@ def configure(dev: usb.core.Device) -> None:
         write_endpoint_address=endpoints["write"].value,
       )
 
-      await reader.setup(empty_buffer=False)  # type: ignore[no-untyped-call]
+      # Use stack to manage the USB resource
+      await stack.enter_async_context(reader)
       self.devices[device_type] = reader
+      stack.callback(functools.partial(self.devices.pop, device_type, None))
 
       # Discover actual endpoints from the USB descriptor, overriding the
       # hardcoded DEVICE_ENDPOINTS values for this specific hardware.
@@ -171,50 +176,43 @@ async def _read_packet_in_executor(
     size: Optional[int] = None,
     timeout: Optional[float] = None,
   ) -> Optional[bytes]:
-    loop = asyncio.get_running_loop()
-    if reader._executor is None:
-      raise RuntimeError("Call setup() first.")
-
-    start_time = time.monotonic()
-
-    while True:
-      # Calculate remaining timeout if a timeout is set
-      current_timeout = timeout
-      if timeout is not None:
-        elapsed = time.monotonic() - start_time
-        if elapsed > timeout:
-          return None  # Timeout
-        current_timeout = timeout - elapsed
-
-      data = await loop.run_in_executor(
-        reader._executor,
-        lambda: reader._read_packet(size=size, timeout=current_timeout, endpoint=endpoint),
-      )
-      if data is None:
-        return None
+    async def do_read() -> Optional[bytes]:
+      while True:
+        # reader._read_packet is async
+        data = await reader._read_packet(size=size, timeout=timeout, endpoint=endpoint)
 
-      # Validation Logic
-      if len(data) < 5:  # Header(4) + Checksum(1) min
-        logging.warning(f"Packet too short ({len(data)}), ignoring: {data.hex()}")
-        continue
+        if data is None:
+          return None
 
-      # Check indicator
-      if data[0] not in PACKET_TYPE:
-        logging.warning(f"Invalid packet indicator {data[0]}, ignoring: {data.hex()}")
-        continue
+        # Validation Logic
+        if len(data) < 5:  # Header(4) + Checksum(1) min
+          logging.warning(f"Packet too short ({len(data)}), ignoring: {data.hex()}")
+          continue
 
-      # Check length
-      # bytes 2-3 are payload length (Big Endian)
-      payload_len = (data[2] << 8) | data[3]
-      expected_len = 4 + payload_len + 1  # Header + Payload + Checksum
-      if len(data) < expected_len:
-        logging.warning(
-          f"Packet data shorter than payload length (got {len(data)}, expected {expected_len}), ignoring: {data.hex()}"
-        )
-        continue
+        # Check indicator
+        if data[0] not in PACKET_TYPE:
+          logging.warning(f"Invalid packet indicator {data[0]}, ignoring: {data.hex()}")
+          continue
+
+        # Check length
+        # bytes 2-3 are payload length (Big Endian)
+        payload_len = (data[2] << 8) | data[3]
+        expected_len = 4 + payload_len + 1  # Header + Payload + Checksum
+        if len(data) < expected_len:
+          logging.warning(
+            f"Packet data shorter than payload length (got {len(data)}, expected {expected_len}), ignoring: {data.hex()}"
+          )
+          continue
+
+        return data
 
-      return data
+    if timeout is not None:
+      with anyio.move_on_after(timeout):
+        return await do_read()
+      return None
+    else:
+      return await do_read()
 
   async def send_command(
     self,
@@ -228,29 +226,27 @@ async def send_command(
     reader = self.devices[device_type]
 
     async with self.lock:
-      # Set up read task before sending command
-      read_task = self._init_read(reader)
-      await asyncio.sleep(0.01)
+      try:
+        response = None
+        async with anyio.create_task_group() as tg:
 
-      response_task = asyncio.create_task(self._get_response(read_task, reader, timeout=timeout))
+          async def get_resp():
+            nonlocal response
+            response = await self._get_response(reader, timeout=timeout)
 
-      try:
-        logging.debug(f"Sending to {device_type.name}: {command_str}")
-        payload = command_str.encode("ascii")
-        payload_len = len(payload)
+          tg.start_soon(get_resp)
 
-        header = bytes([0x01, self.seq_num, 0x00, payload_len])
-        message = header + payload + bytes([self._calculate_checksum(header + payload)])
-        self.seq_num = (self.seq_num + 1) % 256
+          logging.debug(f"Sending to {device_type.name}: {command_str}")
+          payload = command_str.encode("ascii")
+          payload_len = len(payload)
 
-        await reader.write(message)
-        logging.debug(f"Sent message to {device_type.name}: {message.hex()}")
+          header = bytes([0x01, self.seq_num, 0x00, payload_len])
+          message = header + payload + bytes([self._calculate_checksum(header + payload)])
+          self.seq_num = (self.seq_num + 1) % 256
 
-        # Wait for response
-        if not response_task.done():
-          await response_task
+          await reader.write(message)
+          logging.debug(f"Sent message to {device_type.name}: {message.hex()}")
 
-        response = response_task.result()
         logging.debug(f"Response: {response}")
         return (
           response["payload"]["message"]
@@ -260,38 +256,19 @@ async def send_command(
       except Exception as e:
         logging.error(f"Error in send_command to {device_type.name}: {e}", exc_info=True)
         raise
-      finally:
-        if not response_task.done():
-          response_task.cancel()
-          try:
-            await response_task
-          except asyncio.CancelledError:
-            pass
-
-  def _init_read(
-    self,
-    reader: USB,
-    count: int = 512,
-    read_timeout: int = 2000,
-  ) -> "asyncio.Future[Any]":
-    # Convert read_timeout from milliseconds to seconds for USB class.
-    return asyncio.ensure_future(
-      self._read_packet_in_executor(
-        reader=reader,
-        endpoint=None,
-        size=count,
-        timeout=read_timeout / 1000.0,
-      )
-    )
 
   async def _get_response(
     self,
-    read_task: "asyncio.Future[Any]",
     reader: USB,
     timeout: float = 60.0,
   ) -> Optional[Dict[str, Any]]:
     try:
-      data = await read_task
+      data = await self._read_packet_in_executor(
+        reader=reader,
+        endpoint=None,
+        size=512,
+        timeout=2000 / 1000.0,
+      )
 
       if data is None:
         logging.warning("Read task returned None")
@@ -313,35 +290,33 @@
       elif parsed.get("type") == "RespError":
         raise SparkError(parsed)
 
-      deadline = time.monotonic() + timeout
-      while parsed.get("type") != "RespReady" and time.monotonic() < deadline:
-        try:
-          await asyncio.sleep(0.01)
-          logging.debug(f"Still busy, retrying... time left: {deadline - time.monotonic():.1f}s")
-
-          resp = await self._read_packet_in_executor(
-            reader=reader, endpoint=None, size=512, timeout=0.02
-          )
-
-          if resp:
-            logging.debug(f"Read task completed ({len(resp)} bytes): {bytes(resp).hex()}")
-            parsed = parse_single_spark_packet(bytes(resp))
-            logging.debug(f"Parsed: {parsed}")
-            if parsed.get("type") == "RespMessage":
-              self.msgs.append(parsed["payload"])
-            elif parsed.get("type") == "RespError":
-              raise SparkError(parsed)
-        except SparkError:
-          raise
-        except Exception as e:
-          logging.error(f"Error in get_response retry: {e}")
-      if parsed.get("type") != "RespReady":
+      try:
+        with anyio.fail_after(timeout):
+          while parsed.get("type") != "RespReady":
+            try:
+              await anyio.sleep(0.01)
+              logging.debug("Still busy, retrying...")
+
+              resp = await self._read_packet_in_executor(
+                reader=reader, endpoint=None, size=512, timeout=0.02
+              )
+
+              if resp:
+                logging.debug(f"Read task completed ({len(resp)} bytes): {bytes(resp).hex()}")
+                parsed = parse_single_spark_packet(bytes(resp))
+                logging.debug(f"Parsed: {parsed}")
+                if parsed.get("type") == "RespMessage":
+                  self.msgs.append(parsed["payload"])
+                elif parsed.get("type") == "RespError":
+                  raise SparkError(parsed)
+            except SparkError:
+              raise
+            except Exception as e:
+              logging.error(f"Error in get_response retry: {e}")
+      except TimeoutError:
        logging.warning('Timeout waiting for "RespReady" response')
 
       return parsed
-    except asyncio.CancelledError:
-      logging.warning("Read task was cancelled")
-      return None
     except SparkError:
       raise
     except Exception as e:
@@ -352,17 +327,18 @@ def clear_messages(self) -> None:
     """Clear the list of recorded RespMessage payloads."""
     self.msgs = []
 
-  async def start_background_read(
+  @contextlib.asynccontextmanager
+  async def background_read(
     self,
     device_type: SparkDevice,
     read_timeout: int = 100,
-  ) -> Tuple[Optional["asyncio.Task[None]"], Optional[asyncio.Event], Optional[List[bytes]]]:
+  ) -> AsyncIterator[Optional[List[bytes]]]:
     if
device_type not in self.devices: logging.error(f"Device type {device_type} not connected.") - return None, None, None + yield None + return reader = self.devices[device_type] - stop_event = asyncio.Event() results: List[bytes] = [] endpoints = self.get_endpoints(device_type) endpoint = endpoints["read_data"] @@ -371,8 +347,8 @@ async def background_reader() -> None: logging.info( f"Starting background reader for {device_type.name} {endpoint.name} (0x{endpoint.value:02x})" ) - while not stop_event.is_set(): - await asyncio.sleep(0.2) # Avoid tight loop + while True: + await anyio.sleep(0.2) # Avoid tight loop try: # timeout in seconds data = await self._read_packet_in_executor( @@ -384,22 +360,13 @@ async def background_reader() -> None: if data: results.append(bytes(data)) logging.debug(f"Background read {len(data)} bytes: {bytes(data).hex()}") - except asyncio.CancelledError: - logging.info("Background reader cancelled.") - break except Exception as e: logging.error(f"Error in background reader: {e}", exc_info=True) - await asyncio.sleep(0.1) + await anyio.sleep(0.1) logging.info(f"Stopping background reader for {device_type.name} {endpoint.name}") - task = asyncio.create_task(background_reader()) - return task, stop_event, results - - async def close(self) -> None: - for device_type, reader in self.devices.items(): - try: - await reader.stop() # type: ignore[no-untyped-call] - logging.info(f"{device_type.name} resources released.") - except Exception as e: - logging.error(f"Error closing {device_type.name}: {e}") - self.devices = {} + async with contextlib.AsyncExitStack() as stack: + tg = await stack.enter_async_context(anyio.create_task_group()) + stack.callback(tg.cancel_scope.cancel) + tg.start_soon(background_reader) + yield results diff --git a/pylabrobot/plate_reading/tecan/spark20m/spark_reader_async_tests.py b/pylabrobot/plate_reading/tecan/spark20m/spark_reader_async_tests.py index db0bab73ff8..5dd1adb70b3 100644 --- 
a/pylabrobot/plate_reading/tecan/spark20m/spark_reader_async_tests.py +++ b/pylabrobot/plate_reading/tecan/spark20m/spark_reader_async_tests.py @@ -1,8 +1,8 @@ -import asyncio import concurrent.futures import unittest from unittest.mock import AsyncMock, MagicMock, patch +import anyio import pytest pytest.importorskip("usb") @@ -26,9 +26,17 @@ async def asyncTearDown(self) -> None: async def test_connect_success(self) -> None: # Create a mock USB instance mock_usb_instance = AsyncMock() + mock_usb_instance.__aenter__.return_value = mock_usb_instance + mock_usb_instance.__aexit__.return_value = None + mock_usb_instance.dev = MagicMock() # Ensure dev is synchronous self.mock_usb_class.return_value = mock_usb_instance - await self.reader.connect() + async with self.reader: + # Check that it's in devices + # Note: connect iterates all enum members. If all succeed, all are in devices. + # We mock success for all. + self.assertIn(SparkDevice.PLATE_TRANSPORT, self.reader.devices) + self.assertEqual(self.reader.devices[SparkDevice.PLATE_TRANSPORT], mock_usb_instance) # Verify USB initialized for known devices (we iterate all SparkDevices) # Just check for one of them @@ -37,25 +45,14 @@ async def test_connect_success(self) -> None: _, kwargs = self.mock_usb_class.call_args self.assertEqual(kwargs["read_endpoint_address"], SparkEndpoint.INTERRUPT_IN.value) self.assertEqual(kwargs["write_endpoint_address"], SparkEndpoint.BULK_OUT.value) - # Check that endpoint addresses were passed - _, kwargs = self.mock_usb_class.call_args - self.assertEqual(kwargs["read_endpoint_address"], SparkEndpoint.INTERRUPT_IN.value) - self.assertEqual(kwargs["write_endpoint_address"], SparkEndpoint.BULK_OUT.value) - - # Check that it's in devices - # Note: connect iterates all enum members. If all succeed, all are in devices. - # We mock success for all. 
- self.assertIn(SparkDevice.PLATE_TRANSPORT, self.reader.devices) - self.assertEqual(self.reader.devices[SparkDevice.PLATE_TRANSPORT], mock_usb_instance) - - mock_usb_instance.setup.assert_awaited() async def test_connect_no_devices(self) -> None: # USB raising RuntimeError means device not found self.mock_usb_class.side_effect = RuntimeError("Device not found") with self.assertRaisesRegex(ValueError, "Failed to connect to any known Spark devices"): - await self.reader.connect() + async with self.reader: + pass async def test_connect_usb_error(self) -> None: # Device 1: Fails with Exception (not RuntimeError) @@ -65,6 +62,9 @@ async def test_connect_usb_error(self) -> None: # based on input arguments (id_product). mock_usb_success = AsyncMock() + mock_usb_success.dev = MagicMock() # Ensure dev is synchronous + mock_usb_success.__aenter__.return_value = mock_usb_success + mock_usb_success.__aexit__.return_value = None def side_effect( id_vendor: int, @@ -81,12 +81,11 @@ def side_effect( self.mock_usb_class.side_effect = side_effect - await self.reader.connect() - - # Device 1 should not be in devices - self.assertNotIn(SparkDevice.PLATE_TRANSPORT, self.reader.devices) - # Device 2 should be in devices - self.assertIn(SparkDevice.ABSORPTION, self.reader.devices) + async with self.reader: + # Device 1 should not be in devices + self.assertNotIn(SparkDevice.PLATE_TRANSPORT, self.reader.devices) + # Device 2 should be in devices + self.assertIn(SparkDevice.ABSORPTION, self.reader.devices) async def test_send_command(self) -> None: # Setup connected device @@ -109,8 +108,8 @@ def execute_sync(func, *args): mock_dev._executor.submit.side_effect = execute_sync - # Mock _read_packet to avoid TypeError in background task (must be MagicMock, not AsyncMock) - mock_dev._read_packet = MagicMock() + # Mock _read_packet to avoid TypeError in background task (must be AsyncMock) + mock_dev._read_packet = AsyncMock() mock_dev._read_packet.return_value = b"\x81\x00\x00\x00\x00" # Mock 
calculate_checksum to return a predictable value @@ -142,26 +141,24 @@ async def test_get_response_success(self) -> None: ) as mock_parse: mock_parse.return_value = {"type": "RespReady", "payload": {"status": "OK"}} - async def return_bytes() -> bytes: - return b"\x81\x00\x00\x00\x00" - - read_task = asyncio.create_task(return_bytes()) - mock_reader = AsyncMock() - parsed = await self.reader._get_response(read_task, reader=mock_reader) + with patch.object(self.reader, "_read_packet_in_executor") as mock_read_exec: + mock_read_exec.return_value = b"\x81\x00\x00\x00\x00" + + parsed = await self.reader._get_response(reader=mock_reader) self.assertEqual(parsed, {"type": "RespReady", "payload": {"status": "OK"}}) async def test_get_response_busy_then_ready(self) -> None: # This tests the retry loop mock_reader = AsyncMock() - mock_reader._read_packet = MagicMock() + mock_reader._read_packet = AsyncMock() with patch( "pylabrobot.plate_reading.tecan.spark20m.spark_reader_async.parse_single_spark_packet" ) as mock_parse: # Sequence of parse results: - # 1. First read (passed as task): RespMessage (busy/intermediate) + # 1. First read: RespMessage (busy/intermediate) # 2. 
Retry read 1: RespReady mock_parse.side_effect = [ {"type": "RespMessage", "payload": "msg1"}, @@ -185,22 +182,20 @@ def execute_sync(func, *args): # First call inside _get_response (via executor) mock_reader._read_packet.return_value = b"\x81\x00\x00\x00\x00" - async def return_initial_data() -> bytes: - return b"\x81\x00\x00\x00\x00" - - read_task = asyncio.create_task(return_initial_data()) + with patch.object(self.reader, "_read_packet_in_executor") as mock_read_exec: + mock_read_exec.side_effect = [ + b"\x81\x00\x00\x00\x00", # Initial read + b"\x81\x00\x00\x00\x00", # Retry read + ] - parsed = await self.reader._get_response(read_task, reader=mock_reader, timeout=1.0) + parsed = await self.reader._get_response(reader=mock_reader, timeout=1.0) self.assertEqual(parsed, {"type": "RespReady", "payload": "done"}) self.assertIn("msg1", self.reader.msgs) - # Should have called _read_packet once for the retry - mock_reader._read_packet.assert_called() - - async def test_start_background_read(self) -> None: + async def test_background_read(self) -> None: mock_dev = AsyncMock() - mock_dev._read_packet = MagicMock() + mock_dev._read_packet = AsyncMock() self.reader.devices[SparkDevice.ABSORPTION] = mock_dev mock_dev._executor = MagicMock() @@ -235,37 +230,26 @@ def execute_sync(func, *args): # Let's return DATA1, DATA2, then b"" (too short) repeatedly. mock_dev._read_packet.side_effect = [DATA1, DATA2] + [b""] * 100 - # We also need to configure find_descriptor for `_read_from_endpoint` if size is None. - # But start_background_read passes size=1024. 
- - task, stop_event, results = await self.reader.start_background_read(SparkDevice.ABSORPTION) - - assert task is not None - assert stop_event is not None - assert results is not None - - # Let it run to collect data - await asyncio.sleep(0.5) # Wait for 2 reads (0.2 sleep in loop) - - stop_event.set() - task.cancel() - try: - await task - except asyncio.CancelledError: - pass + async with self.reader.background_read(SparkDevice.ABSORPTION) as results: + assert results is not None + # Let it run to collect data + await anyio.sleep(0.5) # Wait for 2 reads (0.2 sleep in loop) self.assertIn(DATA1, results) self.assertIn(DATA2, results) async def test_close(self) -> None: - mock_dev = AsyncMock() - self.reader.devices[SparkDevice.PLATE_TRANSPORT] = mock_dev + mock_usb_instance = AsyncMock() + mock_usb_instance.__aenter__.return_value = mock_usb_instance + mock_usb_instance.__aexit__.return_value = None + mock_usb_instance.dev = MagicMock() # Ensure dev is synchronous + self.mock_usb_class.return_value = mock_usb_instance - await self.reader.close() + async with self.reader: + pass - self.assertEqual(self.reader.devices, {}) - # Ensure stop called on the mocked USB device - mock_dev.stop.assert_awaited() + # Ensure resources released via context manager exit + mock_usb_instance.__aexit__.assert_awaited() async def test_get_response_error(self) -> None: with patch( @@ -273,19 +257,17 @@ async def test_get_response_error(self) -> None: ) as mock_parse: mock_parse.return_value = {"type": "RespError", "payload": {"error": "BadCommand"}} - async def return_error_bytes() -> bytes: - return b"\x86\x00\x00\x00\x00" - - read_task = asyncio.create_task(return_error_bytes()) - mock_reader = AsyncMock() - with self.assertRaises(SparkError): - await self.reader._get_response(read_task, reader=mock_reader) + with patch.object(self.reader, "_read_packet_in_executor") as mock_read_exec: + mock_read_exec.return_value = b"\x86\x00\x00\x00\x00" + + with self.assertRaises(SparkError): + 
await self.reader._get_response(reader=mock_reader) async def test_get_response_empty_packet_retry(self) -> None: # Test that empty packet (ZLP) triggers retry mock_reader = AsyncMock() - mock_reader._read_packet = MagicMock() + mock_reader._read_packet = AsyncMock() # Configure mock executor and device for read retry mock_reader._executor = MagicMock() @@ -304,34 +286,24 @@ def execute_sync(func, *args): "pylabrobot.plate_reading.tecan.spark20m.spark_reader_async.parse_single_spark_packet" ) as mock_parse: # Sequence: - # 1. First read task returns empty bytes -> Triggers ValueError in parser (mocked below) -> retry + # 1. First read returns empty bytes -> Triggers ValueError in parser (mocked below) -> retry # 2. Retry read returns valid data -> Success - # Mock the retry read - mock_reader._read_packet.return_value = b"\x81\x00\x00\x00\x00" - - # Logic: - # _get_response awaits read_task -> returns b"" - # calls parse_single_spark_packet(b"") -> raises ValueError - # catches, loops. - # loop calls _read_from_endpoint -> returns b"retry_data" - # calls parse_single_spark_packet(b"retry_data") -> returns valid + # Mock the reads: initial returns empty, retries return valid + mock_reader._read_packet.side_effect = [ + b"", # Initial read + b"\x81\x00\x00\x00\x00", # First retry in loop + b"\x81\x00\x00\x00\x00", # Second retry in loop + ] mock_parse.side_effect = [ ValueError("Packet too short"), # First call with empty bytes {"type": "RespReady", "payload": "done"}, # Second call with retry_data ] - async def return_empty_bytes() -> bytes: - return b"" - - read_task = asyncio.create_task(return_empty_bytes()) - - parsed = await self.reader._get_response(read_task, reader=mock_reader, timeout=1.0) + parsed = await self.reader._get_response(reader=mock_reader, timeout=1.0) self.assertEqual(parsed, {"type": "RespReady", "payload": "done"}) - # Verify retry happened - mock_reader._read_packet.assert_called() async def test_read_packet_in_executor_retries(self) -> 
None: # Test that _read_packet_in_executor retries on invalid packets using new validation logic @@ -364,7 +336,7 @@ def execute_sync(func, *args): INVALID_TRUNCATED = b"\x81\x00\x00\x05\x00" # Payload len 5, but total len 5 (expect 4+5+1=10) VALID = b"\x81\x00\x00\x00\x00" - mock_reader._read_packet = MagicMock( + mock_reader._read_packet = AsyncMock( side_effect=[INVALID_SHORT, INVALID_INDICATOR, INVALID_TRUNCATED, VALID] ) @@ -386,7 +358,3 @@ def execute_sync(func, *args): self.assertEqual(mock_reader._read_packet.call_count, 4) finally: mock_reader._executor.shutdown() - - -if __name__ == "__main__": - unittest.main() diff --git a/pylabrobot/plate_washing/biotek/el406/backend.py b/pylabrobot/plate_washing/biotek/el406/backend.py index 15bb6b833f1..3bf2824bdbf 100644 --- a/pylabrobot/plate_washing/biotek/el406/backend.py +++ b/pylabrobot/plate_washing/biotek/el406/backend.py @@ -13,11 +13,13 @@ from __future__ import annotations -import asyncio import logging from collections.abc import AsyncIterator from contextlib import asynccontextmanager +import anyio + +from pylabrobot.concurrency import AsyncExitStackWithShielding from pylabrobot.io.ftdi import FTDI from pylabrobot.machines.backend import MachineBackend from pylabrobot.resources import Plate @@ -68,13 +70,10 @@ def __init__( self.timeout = timeout self._device_id = device_id self.io: FTDI | None = None - self._command_lock: asyncio.Lock | None = None + self._command_lock: anyio.Lock | None = None self._in_batch: bool = False - async def setup( - self, - skip_reset: bool = False, - ) -> None: + async def _enter_lifespan(self, stack: AsyncExitStackWithShielding, *, skip_reset: bool = False): """Set up communication with the EL406. Configures the FTDI USB interface with the correct parameters: @@ -83,25 +82,33 @@ async def setup( - No flow control (disabled) If ``self.io`` is already set (e.g. injected mock for testing), - it is used as-is and ``setup()`` is not called on it again. + it is used as-is. 
Note: This does NOT start a batch. Use ``batch()`` or call step commands directly (they auto-batch). Args: + stack: The AsyncExitStack to register cleanups with. skip_reset: If True, skip the instrument reset step. Raises: RuntimeError: If pylibftdi is not installed or communication fails. """ - self._command_lock = asyncio.Lock() + await super()._enter_lifespan(stack) + + self._command_lock = anyio.Lock() logger.info("BioTekEL406Backend setting up") logger.info(" Timeout: %.1f seconds", self.timeout) if self.io is None: self.io = FTDI(human_readable_device_name="BioTek EL406", device_id=self._device_id) - await self.io.setup() + + @stack.callback + def _cleanup(): + self.io = None + + await stack.enter_async_context(self.io) # Configure serial parameters logger.debug("Configuring serial parameters...") @@ -118,8 +125,6 @@ async def setup( await self.io.set_dtr(True) logger.debug(" RTS and DTR enabled") except Exception as e: - await self.io.stop() - self.io = None raise EL406CommunicationError( f"Failed to configure FTDI device: {e}", operation="configure", @@ -146,17 +151,6 @@ async def setup( logger.info("BioTekEL406Backend setup complete") - async def stop(self) -> None: - """Stop communication with the EL406. - - Closes the FTDI connection. Batch cleanup is handled by the ``batch()`` - context manager, not by ``stop()``. - """ - logger.info("BioTekEL406Backend stopping") - if self.io is not None: - await self.io.stop() - self.io = None - @asynccontextmanager async def batch(self, plate: Plate) -> AsyncIterator[None]: """Context manager for batching step commands. 
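The backend hunk above replaces the paired `setup()`/`stop()` methods with a single `_enter_lifespan(stack)` hook that registers every teardown step on an exit stack at the moment the corresponding resource is acquired. A minimal, self-contained sketch of that pattern using only the standard library (the class name, fields, and `_stack` attribute here are illustrative, not PLR's actual `AsyncResource` API):

```python
import asyncio
import contextlib


class Resource:
    """Sketch of the `_enter_lifespan(stack)` pattern: cleanup is registered
    on the stack as each sub-resource is opened, so exit ordering and
    error-path teardown come for free from AsyncExitStack."""

    def __init__(self) -> None:
        self.io = None
        self.log = []

    async def _enter_lifespan(self, stack: contextlib.AsyncExitStack) -> None:
        self.io = object()  # stand-in for opening an FTDI/USB handle
        self.log.append("open")

        @stack.callback
        def _cleanup() -> None:  # runs on exit, even if a later step fails
            self.io = None
            self.log.append("close")

    async def __aenter__(self):
        stack = contextlib.AsyncExitStack()
        try:
            await self._enter_lifespan(stack)
        except BaseException:
            await stack.aclose()  # unwind partial acquisition on failure
            raise
        self._stack = stack
        return self

    async def __aexit__(self, exc_type, exc, tb):
        return await self._stack.__aexit__(exc_type, exc, tb)


async def main() -> None:
    r = Resource()
    async with r:
        assert r.io is not None  # usable only inside the block
    assert r.io is None and r.log == ["open", "close"]


asyncio.run(main())
```

Because all teardown lives on the stack, there is nothing analogous to the removed `stop()` to forget or call twice, which is the point of dropping separate enter/exit calls.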
diff --git a/pylabrobot/plate_washing/biotek/el406/batch_tests.py b/pylabrobot/plate_washing/biotek/el406/batch_tests.py index b7ef54ec295..44d074a3dba 100644 --- a/pylabrobot/plate_washing/biotek/el406/batch_tests.py +++ b/pylabrobot/plate_washing/biotek/el406/batch_tests.py @@ -1,6 +1,5 @@ # mypy: disable-error-code="union-attr,assignment,arg-type" -import unittest from pylabrobot.plate_washing.biotek.el406.mock_tests import PT96, EL406TestCase @@ -154,7 +153,3 @@ async def test_multiple_steps_in_batch_share_single_batch(self): # Two shake commands shake_count = sum(1 for d in written if len(d) >= 3 and d[2] == 0xA3) self.assertEqual(shake_count, 2, "Should have two SHAKE commands") - - -if __name__ == "__main__": - unittest.main() diff --git a/pylabrobot/plate_washing/biotek/el406/communication.py b/pylabrobot/plate_washing/biotek/el406/communication.py index df2be8f6ad4..5a113981d11 100644 --- a/pylabrobot/plate_washing/biotek/el406/communication.py +++ b/pylabrobot/plate_washing/biotek/el406/communication.py @@ -6,11 +6,11 @@ from __future__ import annotations -import asyncio import logging -import time from typing import TYPE_CHECKING, NamedTuple +import anyio + from pylabrobot.io.binary import Reader from .error_codes import get_error_message @@ -53,12 +53,12 @@ class EL406CommunicationMixin: Requires: self.io: FTDI IO wrapper instance self.timeout: Default timeout in seconds - self._command_lock: asyncio.Lock for command serialization + self._command_lock: anyio.Lock for command serialization """ io: FTDI | None timeout: float - _command_lock: asyncio.Lock | None + _command_lock: anyio.Lock | None async def _write_to_device(self, data: bytes) -> None: """Write bytes to the FTDI device, wrapping errors. @@ -76,50 +76,47 @@ async def _write_to_device(self, data: bytes) -> None: original_error=e, ) from e - async def _wait_for_ack(self, timeout: float, t0: float) -> None: - """Poll device for ACK byte within the remaining timeout window. 
+ async def _wait_for_ack(self, timeout: float) -> None: + """Poll device for ACK byte within the timeout window. Args: - timeout: Total timeout budget in seconds. - t0: Start timestamp (from ``time.monotonic()``). + timeout: Timeout budget in seconds. Raises: RuntimeError: If device sends NAK. TimeoutError: If no ACK within timeout. """ assert self.io is not None - while time.monotonic() - t0 < timeout: - byte = await self.io.read(1) - if byte: - if byte[0] == 0x15: # NAK - raise RuntimeError( - f"Device rejected command (NAK). Response: {byte!r}. " - "This may indicate an invalid command, bad parameters, or device busy state." - ) - if byte[0] == 0x06: # ACK - return - await asyncio.sleep(0.01) - raise TimeoutError("Timeout waiting for ACK") - - async def _read_exact_bytes(self, count: int, timeout: float, t0: float) -> bytes: - """Read exactly *count* bytes from the device, polling until done or timeout. + with anyio.fail_after(timeout): + while True: + byte = await self.io.read(1) + if byte: + if byte[0] == 0x15: # NAK + raise RuntimeError( + f"Device rejected command (NAK). Response: {byte!r}. " + "This may indicate an invalid command, bad parameters, or device busy state." + ) + if byte[0] == 0x06: # ACK + return + await anyio.sleep(0.01) + + async def _read_exact_bytes(self, count: int) -> bytes: + """Read exactly *count* bytes from the device, polling until done. Args: count: Number of bytes to read. - timeout: Total timeout budget in seconds. - t0: Start timestamp (from ``time.monotonic()``). Returns: - Bytes read (may be shorter than *count* if timeout is reached). + Bytes read. 
""" assert self.io is not None buf = b"" - while len(buf) < count and time.monotonic() - t0 < timeout: + while len(buf) < count: chunk = await self.io.read(count - len(buf)) if chunk: buf += chunk else: - await asyncio.sleep(0.01) + await anyio.sleep(0.01) return buf async def _purge_buffers(self) -> None: @@ -246,28 +243,30 @@ async def _send_framed_command( logger.debug("Sent header: %s", header.hex()) if data: - await asyncio.sleep(0.001) # Small delay between header and data + await anyio.sleep(0.001) # Small delay between header and data await self._write_to_device(data) logger.debug("Sent data: %s", data.hex()) logger.debug("Sent framed: %s", framed_message.hex()) # Read full response: ACK + 11-byte header + variable data - await self._wait_for_ack(timeout, time.monotonic()) + await self._wait_for_ack(timeout) result = bytes([0x06]) - # Fresh timestamp after ACK — header + data share a single timeout budget. - t0 = time.monotonic() - resp_header = await self._read_exact_bytes(11, timeout, t0) - - if len(resp_header) == 11: - result += resp_header - # Parse data length from header bytes 7-8 (little-endian) - data_len = Reader(resp_header[7:]).u16() - response_data = await self._read_exact_bytes(data_len, timeout, t0) - result += response_data - logger.debug("Full response: %s (%d bytes)", result.hex(), len(result)) - else: - logger.debug("ACK-only response (no frame): %s", result.hex()) + try: + with anyio.fail_after(timeout): + resp_header = await self._read_exact_bytes(11) + + if len(resp_header) == 11: + result += resp_header + # Parse data length from header bytes 7-8 (little-endian) + data_len = Reader(resp_header[7:]).u16() + response_data = await self._read_exact_bytes(data_len) + result += response_data + logger.debug("Full response: %s (%d bytes)", result.hex(), len(result)) + else: + logger.debug("ACK-only response (no frame): %s", result.hex()) + except TimeoutError: + raise TimeoutError("Timeout reading response from EL406") from None return result 
@@ -312,24 +311,26 @@ async def _send_action_command( await self._write_to_device(header) if data: - await asyncio.sleep(0.001) + await anyio.sleep(0.001) await self._write_to_device(data) logger.debug("Sent action command: %s", framed_message.hex()) - t0 = time.monotonic() - - # Step 1: Wait for ACK (short timeout) - await self._wait_for_ack(min(timeout, self.timeout), t0) - logger.debug("Got ACK, waiting for completion...") - - # Step 2: Wait for completion frame (11-byte header + data) - header = await self._read_exact_bytes(11, timeout, t0) - if len(header) < 11: - raise TimeoutError(f"Timeout waiting for completion header (got {len(header)} bytes)") - - # Parse data length and read remaining data - data_len = Reader(header[7:]).u16() - data = await self._read_exact_bytes(data_len, timeout, t0) + try: + with anyio.fail_after(timeout): + # Step 1: Wait for ACK (short timeout) + await self._wait_for_ack(min(timeout, self.timeout)) + logger.debug("Got ACK, waiting for completion...") + + # Step 2: Wait for completion frame (11-byte header + data) + header = await self._read_exact_bytes(11) + if len(header) < 11: + raise TimeoutError(f"Timeout waiting for completion header (got {len(header)} bytes)") + + # Parse data length and read remaining data + data_len = Reader(header[7:]).u16() + data = await self._read_exact_bytes(data_len) + except TimeoutError: + raise TimeoutError("Timeout waiting for completion frame from EL406") from None result = header + data @@ -384,13 +385,13 @@ async def _send_framed_query( logger.debug("Sent query header 0x%04X: %s", command, msg_header.hex()) if msg_data: - await asyncio.sleep(0.001) + await anyio.sleep(0.001) await self._write_to_device(msg_data) logger.debug("Sent query data: %s", msg_data.hex()) # Wait for ACK try: - await self._wait_for_ack(timeout, time.monotonic()) + await self._wait_for_ack(timeout) except RuntimeError as e: raise RuntimeError( f"Device rejected command 0x{command:04X} (NAK). 
Check command code and parameters." @@ -398,20 +399,23 @@ async def _send_framed_query( except TimeoutError as e: raise TimeoutError(f"Timeout waiting for ACK (command 0x{command:04X})") from e - t0 = time.monotonic() - # Read 11-byte response header (shares timeout budget with data) - resp_header = await self._read_exact_bytes(11, timeout, t0) - if len(resp_header) < 11: - raise TimeoutError(f"Timeout reading response header (got {len(resp_header)}/11 bytes)") - - logger.debug("Response header: %s", resp_header.hex()) - - # Parse data length from header bytes 7-8 (little-endian) - data_len = Reader(resp_header[7:]).u16() - logger.debug("Response data length: %d", data_len) - - # Read data bytes - response_data = await self._read_exact_bytes(data_len, timeout, t0) + try: + with anyio.fail_after(timeout): + # Read 11-byte response header (shares timeout budget with data) + resp_header = await self._read_exact_bytes(11) + if len(resp_header) < 11: + raise TimeoutError(f"Timeout reading response header (got {len(resp_header)}/11 bytes)") + + logger.debug("Response header: %s", resp_header.hex()) + + # Parse data length from header bytes 7-8 (little-endian) + data_len = Reader(resp_header[7:]).u16() + logger.debug("Response data length: %d", data_len) + + # Read data bytes + response_data = await self._read_exact_bytes(data_len) + except TimeoutError: + raise TimeoutError("Timeout reading response from EL406") from None if len(response_data) < data_len: raise TimeoutError( f"Timeout reading response data (got {len(response_data)}/{data_len} bytes)" @@ -467,13 +471,17 @@ async def _wait_until_ready(self, timeout: float = 5.0, poll_interval: float = 0 Raises: TimeoutError: If the device stays busy beyond *timeout*. 
""" - t0 = time.monotonic() - while time.monotonic() - t0 < timeout: - poll = await self._poll_device_state() - if poll.state != STATE_RUNNING: - return - await asyncio.sleep(poll_interval) - raise TimeoutError(f"Device still busy (STATE_RUNNING) after {timeout}s waiting for readiness") + try: + with anyio.fail_after(timeout): + while True: + poll = await self._poll_device_state() + if poll.state != STATE_RUNNING: + return + await anyio.sleep(poll_interval) + except TimeoutError: + raise TimeoutError( + f"Device still busy (STATE_RUNNING) after {timeout}s waiting for readiness" + ) async def _send_step_command( self, @@ -523,36 +531,34 @@ async def _send_step_command( logger.debug("Step command sent, got initial response: %s", response.hex()) # 3. Initial delay before polling - await asyncio.sleep(0.5) + await anyio.sleep(0.5) # 4. Poll for completion - t0 = time.monotonic() poll_count = 0 - - logger.debug("Starting polling loop...") - - while time.monotonic() - t0 < timeout: - await asyncio.sleep(poll_interval) - poll_count += 1 - - poll = await self._poll_device_state() - logger.debug("Poll #%d: %d bytes", poll_count, len(poll.raw_response)) - - if poll.state in (STATE_INITIAL, STATE_STOPPED): - logger.debug("Step completed (state=%d) after %d polls", poll.state, poll_count) - if poll.validity != 0: - raise EL406DeviceError(poll.validity, get_error_message(poll.validity)) - return poll.raw_response - - if poll.state == STATE_RUNNING: - logger.debug("Step in progress (state=Running), continuing poll...") - elif poll.state == STATE_PAUSED: - logger.warning("Step is paused (state=3)") - elif poll.status == 0: - # Unknown state with status=0 means done - logger.debug("Done (unknown state=%d, status=0)", poll.state) - return poll.raw_response - else: - logger.debug("Unknown state=%d, status=%d, continuing...", poll.state, poll.status) - - raise TimeoutError(f"Timeout waiting for step completion after {timeout}s") + try: + with anyio.fail_after(timeout): + while True: 
+ await anyio.sleep(poll_interval) + poll_count += 1 + + poll = await self._poll_device_state() + logger.debug("Poll #%d: %d bytes", poll_count, len(poll.raw_response)) + + if poll.state in (STATE_INITIAL, STATE_STOPPED): + logger.debug("Step completed (state=%d) after %d polls", poll.state, poll_count) + if poll.validity != 0: + raise EL406DeviceError(poll.validity, get_error_message(poll.validity)) + return poll.raw_response + + if poll.state == STATE_RUNNING: + logger.debug("Step in progress (state=Running), continuing poll...") + elif poll.state == STATE_PAUSED: + logger.warning("Step is paused (state=3)") + elif poll.status == 0: + # Unknown state with status=0 means done + logger.debug("Done (unknown state=%d, status=0)", poll.state) + return poll.raw_response + else: + logger.debug("Unknown state=%d, status=%d, continuing...", poll.state, poll.status) + except TimeoutError: + raise TimeoutError(f"Timeout waiting for step completion after {timeout}s") diff --git a/pylabrobot/plate_washing/biotek/el406/mock_tests.py b/pylabrobot/plate_washing/biotek/el406/mock_tests.py index 683e84e46b1..da49da41bca 100644 --- a/pylabrobot/plate_washing/biotek/el406/mock_tests.py +++ b/pylabrobot/plate_washing/biotek/el406/mock_tests.py @@ -1,16 +1,17 @@ # mypy: disable-error-code="union-attr,assignment,arg-type,attr-defined" """Mock FTDI IO for EL406 testing.""" -import asyncio -import unittest from unittest.mock import patch +import anyio + from pylabrobot.plate_washing.biotek.el406 import ExperimentalBioTekEL406Backend from pylabrobot.resources import Plate from pylabrobot.resources.utils import create_ordered_items_2d from pylabrobot.resources.well import Well +from pylabrobot.testing.concurrency import AnyioTestBase -_real_sleep = asyncio.sleep +_real_sleep = anyio.sleep async def _noop(*a, **kw): @@ -50,22 +51,26 @@ def _make_plate(name: str, num_wells: int, size_z: float = 14.0) -> Plate: PT1536F = _make_plate("test_1536_flange", 1536, size_z=10.0) -class 
EL406TestCase(unittest.IsolatedAsyncioTestCase): - """Base test case with mock FTDI IO and patched asyncio.sleep.""" +class EL406TestCase(AnyioTestBase): + """Base test case with mock FTDI IO and patched anyio.sleep.""" - async def asyncSetUp(self): - self._sleep_patcher = patch("asyncio.sleep", side_effect=_noop) + async def _enter_lifespan(self, stack): + self._sleep_patcher = patch("anyio.sleep", side_effect=_noop) self._sleep_patcher.start() + stack.callback(self._sleep_patcher.stop) + self.backend = ExperimentalBioTekEL406Backend() self.backend.io = MockFTDI() - await self.backend.setup() + self.backend.io.set_read_buffer(b"\x06" * 500) - async def asyncTearDown(self): - if self.backend.io is not None: - self.backend.io.set_read_buffer(b"\x06" * 500) - await self.backend.stop() - self._sleep_patcher.stop() + await stack.enter_async_context(self.backend) + + def _pre_cleanup(): + if self.backend.io is not None: + self.backend.io.set_read_buffer(b"\x06" * 500) + + stack.callback(_pre_cleanup) class MockFTDI: @@ -84,10 +89,10 @@ def _default_response_buffer() -> bytes: single_response = b"\x06" + header return single_response * 200 - async def setup(self): - pass + async def __aenter__(self): + return self - async def stop(self): + async def __aexit__(self, exc_type, exc_val, exc_tb): pass async def write(self, data: bytes) -> int: diff --git a/pylabrobot/plate_washing/biotek/el406/queries_tests.py b/pylabrobot/plate_washing/biotek/el406/queries_tests.py index 4891605952b..ac6c3bc2814 100644 --- a/pylabrobot/plate_washing/biotek/el406/queries_tests.py +++ b/pylabrobot/plate_washing/biotek/el406/queries_tests.py @@ -4,8 +4,6 @@ This module contains tests for Query methods. 
""" -import unittest - # Import the backend module from pylabrobot.plate_washing.biotek.el406 import ( EL406Sensor, @@ -454,8 +452,8 @@ def _build_multi_query_buffer(self): buf += MockFTDI.build_completion_frame(bytes([0x01, 0x00, 0x00])) return buf - async def asyncSetUp(self): - await super().asyncSetUp() + async def _enter_lifespan(self, stack): + await super()._enter_lifespan(stack) self.backend.io.read_buffer = self._build_multi_query_buffer() async def test_request_instrument_settings_returns_dict(self): @@ -488,7 +486,3 @@ async def test_request_instrument_settings_raises_when_device_not_initialized(se backend = ExperimentalBioTekEL406Backend() with self.assertRaises(RuntimeError): await backend.request_instrument_settings() - - -if __name__ == "__main__": - unittest.main() diff --git a/pylabrobot/plate_washing/biotek/el406/setup_tests.py b/pylabrobot/plate_washing/biotek/el406/setup_tests.py index e0cd1d0a758..d2cd4ecdf91 100644 --- a/pylabrobot/plate_washing/biotek/el406/setup_tests.py +++ b/pylabrobot/plate_washing/biotek/el406/setup_tests.py @@ -17,18 +17,15 @@ async def test_setup_creates_io(self): """Setup should create and configure FTDI IO wrapper.""" backend = ExperimentalBioTekEL406Backend(timeout=0.01) backend.io = MockFTDI() - await backend.setup() - - self.assertIsNotNone(backend.io) + async with backend: + self.assertIsNotNone(backend.io) async def test_stop_closes_device(self): """Stop should close the FTDI device.""" backend = ExperimentalBioTekEL406Backend(timeout=0.01) backend.io = MockFTDI() - await backend.setup() - - self.assertIsNotNone(backend.io) - await backend.stop() + async with backend: + self.assertIsNotNone(backend.io) self.assertIsNone(backend.io) diff --git a/pylabrobot/plate_washing/biotek/el406/steps_aspirate_tests.py b/pylabrobot/plate_washing/biotek/el406/steps_aspirate_tests.py index 22ef6d92c12..e63b5a6c1d6 100644 --- a/pylabrobot/plate_washing/biotek/el406/steps_aspirate_tests.py +++ 
b/pylabrobot/plate_washing/biotek/el406/steps_aspirate_tests.py @@ -187,7 +187,3 @@ def test_aspirate_command_length(self): """Aspirate command should be exactly 22 bytes.""" cmd = self.backend._build_aspirate_command(PT96) self.assertEqual(len(cmd), 22) - - -if __name__ == "__main__": - unittest.main() diff --git a/pylabrobot/plate_washing/biotek/el406/steps_dispense_tests.py b/pylabrobot/plate_washing/biotek/el406/steps_dispense_tests.py index 917ab1c8a79..ce907a025ba 100644 --- a/pylabrobot/plate_washing/biotek/el406/steps_dispense_tests.py +++ b/pylabrobot/plate_washing/biotek/el406/steps_dispense_tests.py @@ -1,8 +1,6 @@ # mypy: disable-error-code="union-attr,assignment,arg-type" """Tests for BioTek EL406 plate washer backend - Dispense operations.""" -import unittest - from pylabrobot.plate_washing.biotek.el406 import ExperimentalBioTekEL406Backend from pylabrobot.plate_washing.biotek.el406.mock_tests import PT96, EL406TestCase @@ -241,7 +239,3 @@ async def test_syringe_dispense_raises_on_timeout(self): self.backend.io.set_read_buffer(b"") # No ACK response with self.assertRaises(TimeoutError): await self.backend.syringe_dispense(PT96, volume=50.0, syringe="A") - - -if __name__ == "__main__": - unittest.main() diff --git a/pylabrobot/plate_washing/biotek/el406/steps_peristaltic_tests.py b/pylabrobot/plate_washing/biotek/el406/steps_peristaltic_tests.py index 95ce8c450cb..f9a8200c998 100644 --- a/pylabrobot/plate_washing/biotek/el406/steps_peristaltic_tests.py +++ b/pylabrobot/plate_washing/biotek/el406/steps_peristaltic_tests.py @@ -556,7 +556,3 @@ async def test_peristaltic_purge_raises_on_timeout(self): self.backend.io.set_read_buffer(b"") # No ACK response with self.assertRaises(TimeoutError): await self.backend.peristaltic_purge(PT96, volume=1000.0) - - -if __name__ == "__main__": - unittest.main() diff --git a/pylabrobot/plate_washing/biotek/el406/steps_prime_tests.py b/pylabrobot/plate_washing/biotek/el406/steps_prime_tests.py index 
14c51cc4370..aadf988d770 100644 --- a/pylabrobot/plate_washing/biotek/el406/steps_prime_tests.py +++ b/pylabrobot/plate_washing/biotek/el406/steps_prime_tests.py @@ -727,7 +727,3 @@ def test_auto_clean_default_duration(self): self.assertEqual(cmd[2], 0x01) self.assertEqual(cmd[3], 0x00) - - -if __name__ == "__main__": - unittest.main() diff --git a/pylabrobot/plate_washing/biotek/el406/steps_shake_tests.py b/pylabrobot/plate_washing/biotek/el406/steps_shake_tests.py index d875586f6cc..c52e7aff15e 100644 --- a/pylabrobot/plate_washing/biotek/el406/steps_shake_tests.py +++ b/pylabrobot/plate_washing/biotek/el406/steps_shake_tests.py @@ -244,7 +244,3 @@ def test_shake_command_max_duration_encoding(self): expected = bytes.fromhex("04010f0e03000f0e00000000") self.assertEqual(cmd, expected) - - -if __name__ == "__main__": - unittest.main() diff --git a/pylabrobot/plate_washing/biotek/el406/steps_wash_tests.py b/pylabrobot/plate_washing/biotek/el406/steps_wash_tests.py index 641a988c1bb..5d2113be648 100644 --- a/pylabrobot/plate_washing/biotek/el406/steps_wash_tests.py +++ b/pylabrobot/plate_washing/biotek/el406/steps_wash_tests.py @@ -920,7 +920,3 @@ def test_all_plate_types_produce_102_bytes(self): cmd = backend._build_wash_composite_command(plate) self.assertEqual(len(cmd), 102, f"Wrong length for {plate.name}") self.assertEqual(cmd[0], expected_prefixes[plate.name], f"Wrong prefix for {plate.name}") - - -if __name__ == "__main__": - unittest.main() diff --git a/pylabrobot/powder_dispensing/backend.py b/pylabrobot/powder_dispensing/backend.py index 161bc627318..99595aa4311 100644 --- a/pylabrobot/powder_dispensing/backend.py +++ b/pylabrobot/powder_dispensing/backend.py @@ -12,14 +12,6 @@ class PowderDispenserBackend(MachineBackend, metaclass=ABCMeta): An abstract class for a powder dispenser backend. 
""" - @abstractmethod - async def setup(self) -> None: - """Set up the powder dispenser.""" - - @abstractmethod - async def stop(self) -> None: - """Close all connections to the powder dispenser and make sure setup() can be called again.""" - @abstractmethod async def dispense( self, dispense_parameters: List[PowderDispense], **backend_kwargs diff --git a/pylabrobot/powder_dispensing/chatterbox.py b/pylabrobot/powder_dispensing/chatterbox.py index 6d2d5fa5640..f74c78a8631 100644 --- a/pylabrobot/powder_dispensing/chatterbox.py +++ b/pylabrobot/powder_dispensing/chatterbox.py @@ -1,5 +1,6 @@ from typing import List +from pylabrobot.concurrency import AsyncExitStackWithShielding from pylabrobot.powder_dispensing.backend import ( DispenseResults, PowderDispense, @@ -10,11 +11,10 @@ class PowderDispenserChatterboxBackend(PowderDispenserBackend): """Chatter box backend for device-free testing. Prints out all operations.""" - async def setup(self) -> None: + async def _enter_lifespan(self, stack: AsyncExitStackWithShielding): + await super()._enter_lifespan(stack) print("Setting up the powder dispenser.") - - async def stop(self) -> None: - print("Stopping the powder dispenser.") + stack.callback(lambda: print("Stopping the powder dispenser.")) async def dispense( self, dispense_parameters: List[PowderDispense], **backend_kwargs diff --git a/pylabrobot/powder_dispensing/chemspeed/crystal_powderdose.py b/pylabrobot/powder_dispensing/chemspeed/crystal_powderdose.py index 1f2b6a22159..a791c32c27f 100644 --- a/pylabrobot/powder_dispensing/chemspeed/crystal_powderdose.py +++ b/pylabrobot/powder_dispensing/chemspeed/crystal_powderdose.py @@ -1,3 +1,4 @@ +from pylabrobot.concurrency import AsyncExitStackWithShielding from pylabrobot.powder_dispensing.backend import ( PowderDispenserBackend, ) @@ -9,10 +10,8 @@ class CrystalPowderdose(PowderDispenserBackend): def __init__(self, arksuite_address: str) -> None: self.arksuite_address = arksuite_address - async def setup(self) -> 
None: - raise NotImplementedError("CrystalPowderdose not implemented yet") - - async def stop(self) -> None: + async def _enter_lifespan(self, stack: AsyncExitStackWithShielding): + await super()._enter_lifespan(stack) raise NotImplementedError("CrystalPowderdose not implemented yet") def serialize(self) -> dict: diff --git a/pylabrobot/powder_dispensing/powder_dispenser_tests.py b/pylabrobot/powder_dispensing/powder_dispenser_tests.py index 3e9b410a808..5ad8390d6d9 100644 --- a/pylabrobot/powder_dispensing/powder_dispenser_tests.py +++ b/pylabrobot/powder_dispensing/powder_dispenser_tests.py @@ -1,4 +1,3 @@ -import unittest from typing import List from unittest.mock import AsyncMock @@ -11,17 +10,12 @@ PowderDispenser, ) from pylabrobot.resources import Cor_96_wellplate_360ul_Fb, Powder +from pylabrobot.testing.concurrency import AnyioTestBase class MockPowderDispenserBackend(PowderDispenserBackend): """A mock backend for testing.""" - async def setup(self) -> None: - pass - - async def stop(self) -> None: - pass - async def dispense( self, dispense_parameters: List[PowderDispense], @@ -35,15 +29,16 @@ async def dispense( return results -class TestPowderDispenser(unittest.IsolatedAsyncioTestCase): +class TestPowderDispenser(AnyioTestBase): """ Test class for PowderDispenser. 
""" - async def asyncSetUp(self) -> None: + async def _enter_lifespan(self, stack): + await super()._enter_lifespan(stack) self.backend = AsyncMock(spec=MockPowderDispenserBackend) self.dispenser = PowderDispenser(backend=self.backend) - await self.dispenser.setup() + await stack.enter_async_context(self.dispenser) async def test_dispense_single_resource(self): plate = Cor_96_wellplate_360ul_Fb(name="test_resource") @@ -85,7 +80,3 @@ async def test_assertion_for_mismatched_lengths(self): [0.005, 0.010], dispense_parameters=[{"param": "value"}, {"param": "value"}], ) - - -if __name__ == "__main__": - unittest.main() diff --git a/pylabrobot/pumps/agrowpumps/agrowdosepump_backend.py b/pylabrobot/pumps/agrowpumps/agrowdosepump_backend.py index 1ba4da80908..486b3e85b28 100644 --- a/pylabrobot/pumps/agrowpumps/agrowdosepump_backend.py +++ b/pylabrobot/pumps/agrowpumps/agrowdosepump_backend.py @@ -1,9 +1,10 @@ -import asyncio import logging -import threading -import time from typing import Dict, List, Optional, Union +import anyio + +from pylabrobot.concurrency import AsyncExitStackWithShielding + try: from pymodbus.client import AsyncModbusSerialClient # type: ignore @@ -45,11 +46,9 @@ def __init__(self, port: str, address: Union[int, str]): if address not in range(0, 256): raise ValueError("Pump address out of range") self.address = int(address) - self._keep_alive_thread: Optional[threading.Thread] = None self._pump_index_to_address: Optional[Dict[int, int]] = None self._modbus: Optional["AsyncModbusSerialClient"] = None self._num_channels: Optional[int] = None - self._keep_alive_thread_active = False @property def modbus(self) -> "AsyncModbusSerialClient": @@ -81,66 +80,51 @@ def num_channels(self) -> int: raise RuntimeError("Number of channels not established") return self._num_channels - def start_keep_alive_thread(self): - """Creates a daemon thread that sends a Modbus request every 25 seconds to keep the connection - alive.""" - - async def keep_alive(): - """Sends 
a Modbus request every 25 seconds to keep the connection alive. - Sleep for 0.1 seconds so we can respond to `stop` events fast. - """ - i = 0 - while self._keep_alive_thread_active: - time.sleep(0.1) - i += 1 - if i == 250: - await self.modbus.read_holding_registers(0, 1, unit=self.address) - i = 0 - - def manage_async_keep_alive(): - """Manages the keep alive thread.""" - try: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - loop.run_until_complete(keep_alive()) - loop.close() - except Exception as e: - logger.error("Error in keep alive thread: %s", e) - - self._keep_alive_thread_active = True - self._keep_alive_thread = threading.Thread(target=manage_async_keep_alive, daemon=True) - self._keep_alive_thread.start() - - async def setup(self): + async def _keep_alive_task(self): + """Sends a Modbus request every 25 seconds to keep the connection alive.""" + while True: + await anyio.sleep(25) + # do a keep-alive + assert self._modbus is not None + await self._modbus.read_holding_registers(0, 1, unit=self.address) + + async def _enter_lifespan(self, stack: AsyncExitStackWithShielding): """Sets up the Modbus connection to the AgrowPumpArray and creates the pump mappings needed to issue commands. """ - await self._setup_modbus() - register_return = await self.modbus.read_holding_registers(19, 2, unit=self.address) - self._num_channels = int( - "".join(chr(r // 256) + chr(r % 256) for r in register_return.registers)[2] - ) - self.start_keep_alive_thread() - self._pump_index_to_address = {pump: pump + 100 for pump in range(0, self.num_channels)} - - async def _setup_modbus(self): - if AsyncModbusSerialClient is None: - raise RuntimeError( - "pymodbus is not installed. Install with: pip install pylabrobot[modbus]." - f" Import error: {_MODBUS_IMPORT_ERROR}" + if self._modbus is None: + if AsyncModbusSerialClient is None: + raise RuntimeError( + "pymodbus is not installed. Install with: pip install pylabrobot[modbus]." 
+ f" Import error: {_MODBUS_IMPORT_ERROR}" + ) + self._modbus = AsyncModbusSerialClient( + port=self.port, + baudrate=115200, + timeout=1, + stopbits=1, + bytesize=8, + parity="E", + retry_on_empty=True, ) - self._modbus = AsyncModbusSerialClient( - port=self.port, - baudrate=115200, - timeout=1, - stopbits=1, - bytesize=8, - parity="E", - retry_on_empty=True, - ) await self.modbus.connect() if not self.modbus.connected: raise ConnectionError("Modbus connection failed during pump setup") + stack.callback(self._modbus.close) + + register_return = await self._modbus.read_holding_registers(19, 2, unit=self.address) + self._num_channels = int( + "".join(chr(r // 256) + chr(r % 256) for r in register_return.registers)[2] + ) + + tg = await stack.enter_async_context(anyio.create_task_group()) + stack.callback(tg.cancel_scope.cancel) + + tg.start_soon(self._keep_alive_task) + + stack.push_shielded_async_callback(self.halt) + + self._pump_index_to_address = {pump: pump + 100 for pump in range(0, self.num_channels)} def serialize(self): return { @@ -195,16 +179,6 @@ async def halt(self): address = self.pump_index_to_address[pump] await self.modbus.write_register(address, 0, unit=self.address) - async def stop(self): - """Close the connection to the pump array.""" - await self.halt() - assert self.modbus is not None, "Modbus connection not established" - if self._keep_alive_thread is not None: - self._keep_alive_thread_active = False - self._keep_alive_thread.join() - self.modbus.close() - assert not self.modbus.connected, "Modbus failing to disconnect" - # Deprecated alias with warning # TODO: remove mid May 2025 (giving people 1 month to update) # https://github.com/PyLabRobot/pylabrobot/issues/466 diff --git a/pylabrobot/pumps/agrowpumps/agrowdosepump_tests.py b/pylabrobot/pumps/agrowpumps/agrowdosepump_tests.py index b9d6047a0e4..d748aa4924e 100644 --- a/pylabrobot/pumps/agrowpumps/agrowdosepump_tests.py +++ b/pylabrobot/pumps/agrowpumps/agrowdosepump_tests.py @@ -1,8 
+1,9 @@ -import unittest from unittest.mock import AsyncMock, call import pytest +from pylabrobot.testing.concurrency import AnyioTestBase + pytest.importorskip("pymodbus") from pymodbus.client import AsyncModbusSerialClient # type: ignore @@ -41,26 +42,20 @@ async def read_holding_registers(self, address: int, count: int, **kwargs): # t write_register = AsyncMock() def close(self, reconnect=False): - assert not self.connected, "Modbus connection not established" + assert self.connected, "Modbus connection not established" + self._connected = False -class TestAgrowPumps(unittest.IsolatedAsyncioTestCase): +class TestAgrowPumps(AnyioTestBase): """TestAgrowPumps allows users to test AgrowPumps.""" - async def asyncSetUp(self): + async def _enter_lifespan(self, stack): self.agrow_backend = AgrowPumpArrayBackend(port="simulated", address=1) - - async def _mock_setup_modbus(): - self.agrow_backend._modbus = SimulatedModbusClient() - - self.agrow_backend._setup_modbus = _mock_setup_modbus # type: ignore[method-assign] + self.agrow_backend._modbus = SimulatedModbusClient(connected=False) self.pump_array = PumpArray(backend=self.agrow_backend, calibration=None) - await self.pump_array.setup() - - async def asyncTearDown(self): - await self.pump_array.stop() + await stack.enter_async_context(self.pump_array) async def test_setup(self): self.assertEqual(self.agrow_backend.port, "simulated") diff --git a/pylabrobot/pumps/backend.py b/pylabrobot/pumps/backend.py index 3ae2a3a151d..ca04d92a54a 100644 --- a/pylabrobot/pumps/backend.py +++ b/pylabrobot/pumps/backend.py @@ -26,9 +26,6 @@ def run_continuously(self, speed: float): def halt(self): """Halt the pump.""" - async def stop(self): - """Close the connection to the pump.""" - class PumpArrayBackend(MachineBackend, metaclass=ABCMeta): """ @@ -61,6 +58,3 @@ async def run_continuously(self, speed: List[float], use_channels: List[int]): async def halt(self): """Halt the entire pump array.""" - - async def stop(self): - """Close 
the connection to the pump array.""" diff --git a/pylabrobot/pumps/chatterbox.py b/pylabrobot/pumps/chatterbox.py index 793792460d2..e23d4129aea 100644 --- a/pylabrobot/pumps/chatterbox.py +++ b/pylabrobot/pumps/chatterbox.py @@ -1,3 +1,4 @@ +import contextlib from typing import List from pylabrobot.pumps.backend import PumpArrayBackend, PumpBackend @@ -6,11 +7,9 @@ class PumpChatterboxBackend(PumpBackend): """Chatter box backend for device-free testing. Prints out all operations.""" - async def setup(self): + async def _enter_lifespan(self, stack: contextlib.AsyncExitStack): print("Setting up the pump.") - - async def stop(self): - print("Stopping the pump.") + stack.callback(lambda: print("Stopping the pump.")) def run_revolutions(self, num_revolutions: float): print(f"Running {num_revolutions} revolutions.") @@ -28,11 +27,9 @@ class PumpArrayChatterboxBackend(PumpArrayBackend): def __init__(self, num_channels: int = 8) -> None: self._num_channels = num_channels - async def setup(self): + async def _enter_lifespan(self, stack: contextlib.AsyncExitStack): print("Setting up the pump array.") - - async def stop(self): - print("Stopping the pump array.") + stack.callback(lambda: print("Stopping the pump array.")) @property def num_channels(self) -> int: diff --git a/pylabrobot/pumps/cole_parmer/masterflex_backend.py b/pylabrobot/pumps/cole_parmer/masterflex_backend.py index 41cdf5b24a0..4676af80c30 100644 --- a/pylabrobot/pumps/cole_parmer/masterflex_backend.py +++ b/pylabrobot/pumps/cole_parmer/masterflex_backend.py @@ -6,6 +6,7 @@ HAS_SERIAL = False _SERIAL_IMPORT_ERROR = e +from pylabrobot.concurrency import AsyncExitStackWithShielding from pylabrobot.io.serial import Serial from pylabrobot.pumps.backend import PumpBackend @@ -46,18 +47,15 @@ def __init__(self, com_port: str): human_readable_device_name="Masterflex Pump", ) - async def setup(self): - await self.io.setup() - + async def _enter_lifespan(self, stack: AsyncExitStackWithShielding): + await 
super()._enter_lifespan(stack) + await stack.enter_async_context(self.io) await self.io.write(b"\x05") # Enquiry; ready to send. await self.io.write(b"\x05P02\r") def serialize(self): return {**super().serialize(), "com_port": self.com_port} - async def stop(self): - await self.io.stop() - async def send_command(self, command: str): command = "\x02P02" + command + "\x0d" await self.io.write(command.encode()) diff --git a/pylabrobot/pumps/pump.py b/pylabrobot/pumps/pump.py index 9d92f7257ba..01e986ba519 100644 --- a/pylabrobot/pumps/pump.py +++ b/pylabrobot/pumps/pump.py @@ -1,6 +1,7 @@ -import asyncio from typing import Optional, Union +import anyio + from pylabrobot.machines.machine import Machine from .backend import PumpBackend @@ -71,7 +72,7 @@ async def run_for_duration(self, speed: Union[float, int], duration: Union[float if duration < 0: raise ValueError("Duration must be positive.") await self.run_continuously(speed=speed) - await asyncio.sleep(duration) + await anyio.sleep(duration) await self.run_continuously(speed=0) async def pump_volume(self, speed: Union[float, int], volume: Union[float, int]): diff --git a/pylabrobot/pumps/pump_tests.py b/pylabrobot/pumps/pump_tests.py index cc8c3ab7101..d31839c5c27 100644 --- a/pylabrobot/pumps/pump_tests.py +++ b/pylabrobot/pumps/pump_tests.py @@ -1,4 +1,3 @@ -import unittest from unittest.mock import AsyncMock, Mock from pylabrobot.pumps import PumpArray @@ -6,16 +5,19 @@ from pylabrobot.pumps.calibration import PumpCalibration from pylabrobot.pumps.errors import NotCalibratedError from pylabrobot.pumps.pump import Pump +from pylabrobot.testing.concurrency import AnyioTestBase -class TestPump(unittest.IsolatedAsyncioTestCase): +class TestPump(AnyioTestBase): """Tests for the Pump class. Currently, only the Cole Palmer Masterflex pump is implemented. 
""" - def setUp(self): + async def _enter_lifespan(self, stack): self.mock_backend = Mock(spec=PumpBackend) + self.mock_backend.__aenter__ = AsyncMock(return_value=self.mock_backend) + self.mock_backend.__aexit__ = AsyncMock(return_value=None) self.test_calibration = PumpCalibration.load_calibration(1, num_items=1) async def test_setup(self): @@ -30,22 +32,16 @@ async def test_run_revolutions(self): await pump.run_revolutions(num_revolutions=1) -class TestPumpArray(unittest.IsolatedAsyncioTestCase): +class TestPumpArray(AnyioTestBase): """Tests for the AgrowPumpArrayTester class.""" - def setUp(self): - self.mock_backend = Mock(spec=PumpArrayBackend) + async def _enter_lifespan(self, stack): + self.mock_backend = AsyncMock(spec=PumpArrayBackend) self.mock_backend.num_channels = 6 self.test_calibration = PumpCalibration.load_calibration(1, num_items=6) - async def asyncSetUp(self) -> None: - await super().asyncSetUp() self.pump_array = PumpArray(backend=self.mock_backend, calibration=None) - await self.pump_array.setup() - - async def asyncTearDown(self) -> None: - await self.pump_array.stop() - await super().asyncTearDown() + await stack.enter_async_context(self.pump_array) async def test_setup(self): """Test that the AgrowPumpArrayTester class can be initialized.""" @@ -124,7 +120,3 @@ async def test_invalid_volume(self): self.pump_array.calibration = self.test_calibration with self.assertRaises(ValueError): await self.pump_array.pump_volume(speed=1, use_channels=[0], volume=-1) - - -if __name__ == "__main__": - unittest.main() diff --git a/pylabrobot/pumps/pumparray.py b/pylabrobot/pumps/pumparray.py index daedb626a46..f8f65bb0316 100644 --- a/pylabrobot/pumps/pumparray.py +++ b/pylabrobot/pumps/pumparray.py @@ -1,6 +1,8 @@ -import asyncio +import functools from typing import List, Optional, Union +import anyio + from pylabrobot.machines.machine import Machine from pylabrobot.pumps.backend import PumpArrayBackend from pylabrobot.pumps.calibration import 
PumpCalibration @@ -127,7 +129,7 @@ async def run_for_duration( if duration < 0: raise ValueError("Duration must be positive.") await self.run_continuously(speed=speed, use_channels=use_channels) - await asyncio.sleep(duration) + await anyio.sleep(duration) await self.run_continuously(speed=0, use_channels=use_channels) async def pump_volume( @@ -163,35 +165,38 @@ async def pump_volume( raise ValueError("Volume must be positive.") if not len(speed) == len(use_channels) == len(volume): raise ValueError("Speed, use_channels, and volume must be the same length.") + if self.calibration.calibration_mode == "duration": durations = [ channel_volume / self.calibration[channel] for channel, channel_volume in zip(use_channels, volume) ] - tasks = [ - asyncio.create_task( - self.run_for_duration( - speed=channel_speed, - use_channels=channel, - duration=duration, + async with anyio.create_task_group() as tg: + for channel_speed, channel, duration in zip(speed, use_channels, durations): + tg.start_soon( + functools.partial( + self.run_for_duration, + speed=channel_speed, + use_channels=channel, + duration=duration, + ) ) - ) - for channel_speed, channel, duration in zip(speed, use_channels, durations) - ] elif self.calibration.calibration_mode == "revolutions": num_rotations = [ channel_volume / self.calibration[channel] for channel, channel_volume in zip(use_channels, volume) ] - tasks = [ - asyncio.create_task( - self.run_revolutions(num_revolutions=num_rotation, use_channels=channel) - ) - for num_rotation, channel in zip(num_rotations, use_channels) - ] + async with anyio.create_task_group() as tg: + for num_rotation, channel in zip(num_rotations, use_channels): + tg.start_soon( + functools.partial( + self.run_revolutions, + num_revolutions=num_rotation, + use_channels=channel, + ) + ) else: raise ValueError("Calibration mode must be 'duration' or 'revolutions'.") - await asyncio.gather(*tasks) async def halt(self): """Halt the entire pump array.""" diff --git 
a/pylabrobot/resources/resource_holder_tests.py b/pylabrobot/resources/resource_holder_tests.py index 0efcacae13c..3428fe0ece7 100644 --- a/pylabrobot/resources/resource_holder_tests.py +++ b/pylabrobot/resources/resource_holder_tests.py @@ -29,7 +29,3 @@ def test_unassign_with_none(self): def test_assign_none_when_empty(self): self.holder.resource = None self.assertIsNone(self.holder.resource) - - -if __name__ == "__main__": - unittest.main() diff --git a/pylabrobot/scales/chatterbox.py b/pylabrobot/scales/chatterbox.py index 9ffb5a68e05..de4cfab2ed3 100644 --- a/pylabrobot/scales/chatterbox.py +++ b/pylabrobot/scales/chatterbox.py @@ -1,3 +1,5 @@ +import contextlib + from pylabrobot.scales.scale_backend import ScaleBackend @@ -6,12 +8,15 @@ class ScaleChatterboxBackend(ScaleBackend): def __init__(self, dummy_weight: float = 0.0) -> None: self._dummy_weight = dummy_weight + super().__init__() - async def setup(self) -> None: + async def _enter_lifespan(self, stack: contextlib.AsyncExitStack): print("Setting up the scale.") - async def stop(self) -> None: - print("Stopping the scale.") + def _cleanup(): + print("Stopping the scale.") + + stack.callback(_cleanup) async def tare(self): print("Taring the scale") diff --git a/pylabrobot/scales/mettler_toledo_backend.py b/pylabrobot/scales/mettler_toledo_backend.py index ae73eb114df..adee940403e 100644 --- a/pylabrobot/scales/mettler_toledo_backend.py +++ b/pylabrobot/scales/mettler_toledo_backend.py @@ -1,11 +1,12 @@ # similar library: https://github.com/janelia-pypi/mettler_toledo_device_python -import asyncio import logging -import time import warnings from typing import List, Literal, Optional, Union +import anyio + +from pylabrobot.concurrency import AsyncExitStackWithShielding from pylabrobot.io.serial import Serial from pylabrobot.scales.scale_backend import ScaleBackend @@ -172,9 +173,9 @@ def __init__(self, port: Optional[str] = None, vid: int = 0x0403, pid: int = 0x6 timeout=1, ) - async def setup(self) -> 
None:
-    # Core state
-    await self.io.setup()
+  async def _enter_lifespan(self, stack: AsyncExitStackWithShielding) -> None:
+    await super()._enter_lifespan(stack)
+    await stack.enter_async_context(self.io)

     # set output unit to grams
     await self.send_command("M21 0 0")
@@ -183,9 +184,6 @@ async def setup(self) -> None:
     self.serial_number = await self.request_serial_number()
     # TODO: verify serial number pattern

-  async def stop(self) -> None:
-    await self.io.stop()
-
   def serialize(self) -> dict:
     return {**super().serialize(), "port": self.io.port}

@@ -266,14 +264,12 @@ async def send_command(self, command: str, timeout: int = 60) -> MettlerToledoRe
     await self.io.write(command.encode() + b"\r\n")

     raw_response = b""
-    timeout_time = time.time() + timeout
-    while True:
-      raw_response = await self.io.readline()
-      await asyncio.sleep(0.001)
-      if time.time() > timeout_time:
-        raise TimeoutError("Timeout while waiting for response from scale.")
-      if raw_response != b"":
-        break
+    with anyio.fail_after(timeout):
+      while True:
+        raw_response = await self.io.readline()
+        if raw_response != b"":
+          break
+        await anyio.sleep(0.001)

     logger.debug("[scale] Received response: %s", raw_response)
     response = raw_response.decode("utf-8").strip().split()
diff --git a/pylabrobot/sealing/a4s_backend.py b/pylabrobot/sealing/a4s_backend.py
index 5c07cb4a333..fe3d9e382b3 100644
--- a/pylabrobot/sealing/a4s_backend.py
+++ b/pylabrobot/sealing/a4s_backend.py
@@ -1,9 +1,9 @@
-import asyncio
 import dataclasses
 import enum
-import time
 from typing import Set

+import anyio
+
 try:
   import serial

@@ -12,6 +12,7 @@
   HAS_SERIAL = False
   _SERIAL_IMPORT_ERROR = e

+from pylabrobot.concurrency import AsyncExitStackWithShielding
 from pylabrobot.io.serial import Serial
 from pylabrobot.sealing.backend import SealerBackend

@@ -35,13 +36,16 @@ def __init__(self, port: str, timeout=20) -> None:
       human_readable_device_name="A4S Sealer",
     )

-  async def setup(self):
-    await self.io.setup()
-    await self.system_reset()
+  async def _enter_lifespan(self, stack: AsyncExitStackWithShielding) -> None:
+    await super()._enter_lifespan(stack)
+    await stack.enter_async_context(self.io)
+
+    async def cleanup():
+      await self.set_heater(on=False)

-  async def stop(self):
-    await self.set_heater(on=False)
-    await self.io.stop()
+    stack.push_shielded_async_callback(cleanup)
+
+    await self.system_reset()

   async def set_heater(self, on: bool):
     """Set the heater on or off."""
@@ -86,20 +90,21 @@ class SensorStatus:

   async def _read_message(self) -> str:
     """read a message. we are not sure what format it is."""
-    start = time.time()
     r, x = b"", b""
     has_read_r = False
-    while x != b"" or (len(r) == 0 and x == b""):
-      x = await self.io.read()
-      if has_read_r:
-        r += x
-      if x == b"\r":
-        if not has_read_r:
-          has_read_r = True
-        else:
-          break
-      if time.time() - start > self.timeout:
-        raise TimeoutError("Timeout while waiting for response")
+    try:
+      with anyio.fail_after(self.timeout):
+        while x != b"" or (len(r) == 0 and x == b""):
+          x = await self.io.read()
+          if has_read_r:
+            r += x
+          if x == b"\r":
+            if not has_read_r:
+              has_read_r = True
+            else:
+              break
+    except TimeoutError:
+      raise TimeoutError(f"Timeout reading message after {self.timeout} seconds") from None
     return r.decode("utf-8")

   async def get_status(self) -> Status:
@@ -144,26 +149,28 @@ async def get_status(self) -> Status:
     )

   async def _wait_for_status(self, statuses: Set["A4SBackend.Status.SystemStatus"]) -> Status:
-    start = time.time()
-    while True:
-      status = await self.get_status()
-
-      if status.system_status == A4SBackend.Status.SystemStatus.error:
-        raise RuntimeError(f"An error occurred: {status.error_code}")
+    try:
+      with anyio.fail_after(self.timeout):
+        while True:
+          status = await self.get_status()

-      if status.system_status in statuses:
-        return status
+          if status.system_status == A4SBackend.Status.SystemStatus.error:
+            raise RuntimeError(f"An error occurred: {status.error_code}")

-      if time.time() - start > self.timeout:
-        raise TimeoutError("Timeout while waiting for response")
+          if status.system_status in statuses:
+            return status

-      await asyncio.sleep(0.01)
+          await anyio.sleep(0.01)
+    except TimeoutError:
+      raise TimeoutError(
+        f"Timeout waiting for status {statuses} after {self.timeout} seconds"
+      ) from None

   async def send_command(self, command: str):
     # command accepted: *Y01PL!
     # Command index: 01
     await self.io.write(command.encode())
-    await asyncio.sleep(0.1)
+    await anyio.sleep(0.1)

   async def seal(self, temperature: int, duration: float):
     await self.set_temperature(temperature)
@@ -175,25 +182,30 @@ async def seal(self, temperature: int, duration: float):
     )

   async def _wait_for_temperature(self, degrees: float, timeout: float, tolerance: float = 0.5):
-    start = time.time()
-    while True:
-      current_temperature = await self.get_temperature()
-      if abs(current_temperature - degrees) < tolerance:
-        break
-      if time.time() - start > timeout:
-        raise TimeoutError("Timeout while waiting for temperature")
-      await asyncio.sleep(0.1)
+    try:
+      with anyio.fail_after(timeout):
+        while True:
+          current_temperature = await self.get_temperature()
+          if abs(current_temperature - degrees) < tolerance:
+            break
+          await anyio.sleep(0.1)
+    except TimeoutError:
+      raise TimeoutError(f"Temperature did not reach target within {timeout} seconds") from None

   async def _wait_for_shuttle_open_sensor(
     self, shuttle_open: bool, timeout: float = 30.0
   ) -> Status:
-    start = time.time()
-    while True:
-      status = await self.get_status()
-      if status.sensor_status.shuttle_open_sensor == shuttle_open:
-        return status
-      if time.time() - start > timeout:
-        raise TimeoutError("Timeout while waiting for shuttle open sensor")
+    try:
+      with anyio.fail_after(timeout):
+        while True:
+          status = await self.get_status()
+          if status.sensor_status.shuttle_open_sensor == shuttle_open:
+            return status
+          await anyio.sleep(0.1)
+    except TimeoutError:
+      raise TimeoutError(
+        f"Timeout waiting for shuttle open sensor to be {shuttle_open} after {timeout} seconds"
+      ) from None

   async def set_temperature(self, temperature: float):
     if not (50 <= temperature <= 200):
diff --git a/pylabrobot/shaking/chatterbox.py b/pylabrobot/shaking/chatterbox.py
index 8fcfc2933f7..5cc64704dd4 100644
--- a/pylabrobot/shaking/chatterbox.py
+++ b/pylabrobot/shaking/chatterbox.py
@@ -1,3 +1,4 @@
+from pylabrobot.concurrency import AsyncExitStackWithShielding
 from pylabrobot.shaking import ShakerBackend


@@ -6,11 +7,10 @@ class ShakerChatterboxBackend(ShakerBackend):

   temperature: float = 0

-  async def setup(self):
+  async def _enter_lifespan(self, stack: AsyncExitStackWithShielding):
+    await super()._enter_lifespan(stack)
     print("Setting up shaker")
-
-  async def stop(self):
-    print("Stopping shaker")
+    stack.callback(lambda: print("Stopping shaker"))

   async def start_shaking(self, speed: float):
     print("Shaking at speed", speed)
diff --git a/pylabrobot/shaking/shaker.py b/pylabrobot/shaking/shaker.py
index a503279e5d4..98ccb405b1d 100644
--- a/pylabrobot/shaking/shaker.py
+++ b/pylabrobot/shaking/shaker.py
@@ -1,6 +1,7 @@
-import asyncio
 from typing import Optional

+import anyio
+
 from pylabrobot.machines.machine import Machine
 from pylabrobot.resources import Coordinate, ResourceHolder

@@ -48,7 +49,7 @@ async def shake(self, speed: float, duration: Optional[float] = None, **backend_
     if duration is None:
       return

-    await asyncio.sleep(duration)
+    await anyio.sleep(duration)
     await self.backend.stop_shaking()
     if self.backend.supports_locking:
       await self.backend.unlock_plate()
diff --git a/pylabrobot/storage/chatterbox.py b/pylabrobot/storage/chatterbox.py
index 89115098f00..97f0bed5d7f 100644
--- a/pylabrobot/storage/chatterbox.py
+++ b/pylabrobot/storage/chatterbox.py
@@ -1,3 +1,4 @@
+from pylabrobot.concurrency import AsyncExitStackWithShielding
 from pylabrobot.resources.carrier import PlateHolder
 from pylabrobot.resources.plate import Plate
 from pylabrobot.storage.backend import IncubatorBackend

@@ -7,11 +8,10 @@ class IncubatorChatterboxBackend(IncubatorBackend):
   def __init__(self):
     self._dummy_temperature = 37.0

-  async def setup(self):
+  async def _enter_lifespan(self, stack: AsyncExitStackWithShielding):
+    await super()._enter_lifespan(stack)
     print("Setting up incubator backend")
-
-  async def stop(self):
-    print("Stopping incubator backend")
+    stack.callback(lambda: print("Stopping incubator backend"))

   async def open_door(self):
     print("Opening door")
diff --git a/pylabrobot/storage/cytomat/cytomat.py b/pylabrobot/storage/cytomat/cytomat.py
index ca9227fe003..ee724c20090 100644
--- a/pylabrobot/storage/cytomat/cytomat.py
+++ b/pylabrobot/storage/cytomat/cytomat.py
@@ -1,9 +1,9 @@
-import asyncio
 import logging
-import time
 import warnings
 from typing import List, Literal, Optional, Union, cast

+import anyio
+
 try:
   import serial

@@ -12,6 +12,7 @@
   HAS_SERIAL = False
   _SERIAL_IMPORT_ERROR = e

+from pylabrobot.concurrency import AsyncExitStackWithShielding
 from pylabrobot.io.serial import Serial
 from pylabrobot.resources import Plate, PlateCarrier, PlateHolder
 from pylabrobot.storage.backend import IncubatorBackend
@@ -92,8 +93,9 @@ def __init__(self, model: Union[CytomatType, str], port: str):
       human_readable_device_name="Cytomat",
     )

-  async def setup(self):
-    await self.io.setup()
+  async def _enter_lifespan(self, stack: AsyncExitStackWithShielding):
+    await super()._enter_lifespan(stack)
+    await stack.enter_async_context(self.io)
     await self.initialize()
     await self.wait_for_task_completion()

@@ -101,9 +103,6 @@ async def set_racks(self, racks: List[PlateCarrier]):
     await super().set_racks(racks)
     warnings.warn("Cytomat racks need to be configured with the exe software")

-  async def stop(self):
-    await self.io.stop()
-
   def _assemble_command(self, command_type: str, command: str, params: str):
     carriage_return = "\r" if self.model == CytomatType.C2C_425 else "\r\n"
     command = f"{command_type}:{command} {params}".strip() + carriage_return
@@ -143,17 +142,16 @@ async def _send_command(command_str) -> str:
     # which costs 1s if there is a true error, but is necessary to avoid false negatives.
     command_str = self._assemble_command(command_type=command_type, command=command, params=params)
     n_retries = 10
-    exc: Optional[BaseException] = None
-    for _ in range(n_retries):
+    for attempt in reversed(range(n_retries)):
       try:
         return await _send_command(command_str)
-      except (CytomatCommandUnknownError, CytomatBusyError) as e:
-        exc = e
-        await asyncio.sleep(0.1)
+      except (CytomatCommandUnknownError, CytomatBusyError):
+        if not attempt:
+          await self.reset_error_register()
+          raise
+        await anyio.sleep(0.1)
         continue
-    assert exc is not None
-    await self.reset_error_register()
-    raise exc
+    raise RuntimeError("Internal error - this should be unreachable.")

   async def send_action(
     self, command_type: str, command: str, params: str, timeout: Optional[int] = 60
@@ -201,7 +199,7 @@ async def get_overview_register(self) -> OverviewRegisterState:
       try:
         resp = await self.send_command("ch", "bs", "")
       except (CytomatCommandUnknownError, CytomatBusyError):
-        await asyncio.sleep(0.1)
+        await anyio.sleep(0.1)
         continue
       return OverviewRegisterState.from_resp(resp)
     await self.reset_error_register()
@@ -337,7 +335,7 @@ async def action_read_barcode(
   async def wait_for_transfer_station(self, occupied: bool = False):
     """Wait for the transfer station to be occupied, or unoccupied."""
     while (await self.get_overview_register()).transfer_station_occupied != occupied:
-      await asyncio.sleep(1)
+      await anyio.sleep(1)

   async def wait_for_task_completion(self, timeout=60) -> OverviewRegisterState:
     """
@@ -346,20 +344,21 @@ async def wait_for_task_completion(self, timeout=60) -> OverviewRegisterState:
     If the error bit is set in the overview register, the error register is read and the
     corresponding error is raised.
     """
-    start = time.time()
-    while True:
-      overview_register = await self.get_overview_register()
-      if not overview_register.busy_bit_set:
-        # only check for errors once the cytomat is done, so that the user has the chance to
-        # handle the error and proceed if desired.
-        if overview_register.error_register_set:
-          error_register = await self.get_error_register()
-          await self.reset_error_register()
-          raise error_register_map[error_register]
-        return overview_register
-      await asyncio.sleep(1)
-      if time.time() - start > timeout:
-        raise TimeoutError("Cytomat did not complete task in time")
+    try:
+      with anyio.fail_after(timeout):
+        while True:
+          overview_register = await self.get_overview_register()
+          if not overview_register.busy_bit_set:
+            # only check for errors once the cytomat is done, so that the user has the chance to
+            # handle the error and proceed if desired.
+            if overview_register.error_register_set:
+              error_register = await self.get_error_register()
+              await self.reset_error_register()
+              raise error_register_map[error_register]
+            return overview_register
+          await anyio.sleep(1)
+    except TimeoutError:
+      raise TimeoutError("Cytomat did not complete task in time") from None

   async def init_shakers(self):
     return hex_to_binary(await self.send_command("ll", "vi", ""))
@@ -423,11 +422,10 @@ def serialize(self) -> dict:


 class CytomatChatterbox(CytomatBackend):
-  async def setup(self):
+  async def _enter_lifespan(self, stack: AsyncExitStackWithShielding):
+    await IncubatorBackend._enter_lifespan(self, stack)
     await self.wait_for_task_completion()
-
-  async def stop(self):
-    print("closing connection to cytomat")
+    stack.callback(lambda: print("closing connection to cytomat"))

   async def send_command(self, command_type, command, params):
     print(
diff --git a/pylabrobot/storage/cytomat/heraeus_cytomat_backend.py b/pylabrobot/storage/cytomat/heraeus_cytomat_backend.py
index a991d55fc59..c0a382c36ea 100644
--- a/pylabrobot/storage/cytomat/heraeus_cytomat_backend.py
+++ b/pylabrobot/storage/cytomat/heraeus_cytomat_backend.py
@@ -1,9 +1,9 @@
-import asyncio
 import logging
-import time
 import warnings
 from typing import List, Tuple

+import anyio
+
 try:
   import serial

@@ -12,6 +12,7 @@
   HAS_SERIAL = False
   _SERIAL_IMPORT_ERROR = e

+from pylabrobot.concurrency import AsyncExitStackWithShielding
 from pylabrobot.io.serial import Serial
 from pylabrobot.resources import Plate, PlateHolder
 from pylabrobot.resources.carrier import PlateCarrier
@@ -55,7 +56,7 @@ def __init__(self, port: str):
       human_readable_device_name="Heraeus Cytomat",
     )

-  async def setup(self) -> Serial:
+  async def _enter_lifespan(self, stack: AsyncExitStackWithShielding) -> None:
     """
     1. Open serial port (9600 8E1, RTS/CTS) via the Serial wrapper.
     2. Send >200 ms break, wait 150 ms, flush buffers.
@@ -63,45 +64,42 @@ async def setup(self) -> Serial:
     4. Activate handling: ST 1801 → expect OK
     5. Poll ready-flag: RD 1915 → wait for "1"
     """
+    await super()._enter_lifespan(stack)
     try:
-      await self.io.setup()
+      await stack.enter_async_context(self.io)
     except serial.SerialException as e:
       raise RuntimeError(f"Could not open {self.io.port}: {e}")

     await self.io.send_break(duration=0.2)  # >100 ms required
-    await asyncio.sleep(0.15)
+    await anyio.sleep(0.15)
     await self.io.reset_input_buffer()
     await self.io.reset_output_buffer()

     await self.io.write(b"CR\r")
-    deadline = time.time() + self.init_timeout
-    while time.time() < deadline:
-      resp = await self.io.readline()  # reads through LF
-      if resp.strip() == b"CC":
-        break
-    else:
-      await self.io.stop()
-      raise TimeoutError(f"No CC response from PLC within {self.init_timeout} seconds")
+    try:
+      with anyio.fail_after(self.init_timeout):
+        while True:
+          resp = await self.io.readline()  # reads through LF
+          if resp.strip() == b"CC":
+            break
+    except TimeoutError:
+      raise TimeoutError(f"No CC response from PLC within {self.init_timeout} seconds") from None

     await self.io.write(b"ST 1801\r")
     resp = await self.io.readline()
     if resp.strip() != b"OK":
-      await self.io.stop()
       raise RuntimeError(f"Unexpected reply to ST 1801: {resp!r}")

-    deadline = time.time() + self.start_timeout
-    while time.time() < deadline:
-      await self.io.write(b"RD 1915\r")
-      flag = await self.io.readline()
-      if flag.strip() == b"1":
-        return self.io
-      await asyncio.sleep(self.poll_interval)
-
-    await self.io.stop()
-    raise TimeoutError(f"PLC did not signal ready within {self.start_timeout} seconds")
-
-  async def stop(self):
-    await self.io.stop()
+    try:
+      with anyio.fail_after(self.start_timeout):
+        while True:
+          await self.io.write(b"RD 1915\r")
+          flag = await self.io.readline()
+          if flag.strip() == b"1":
+            return
+          await anyio.sleep(self.poll_interval)
+    except TimeoutError:
+      raise TimeoutError(f"PLC did not signal ready within {self.start_timeout} seconds") from None

   async def set_racks(self, racks: List[PlateCarrier]):
     await super().set_racks(racks)
@@ -179,7 +177,7 @@ async def _send_command(self, command: str) -> str:

   async def wait_for_transfer_station(self, occupied: bool = False):
     while (await self.read_plate_detection_xfer()) != occupied:
-      await asyncio.sleep(1)
+      await anyio.sleep(1)

   async def read_plate_detection_xfer(self) -> bool:
     """Read Plate Detection Transfer Station (RD 1813)."""
@@ -190,14 +188,15 @@ async def _wait_ready(self, timeout: int = 60):
     """
     Poll the ready flag (RD 1915) until it becomes '1' or timeout.
     """
-    start = time.time()
-    while True:
-      resp = await self._send_command("RD 1915")
-      if resp == "1":
-        return
-      await asyncio.sleep(0.1)
-      if time.time() - start > timeout:
-        raise TimeoutError("Legacy Cytomat did not become ready in time")
+    try:
+      with anyio.fail_after(timeout):
+        while True:
+          resp = await self._send_command("RD 1915")
+          if resp == "1":
+            return
+          await anyio.sleep(0.1)
+    except TimeoutError:
+      raise TimeoutError("Legacy Cytomat did not become ready in time") from None

   def serialize(self) -> dict:
     return {
diff --git a/pylabrobot/storage/incubator.py b/pylabrobot/storage/incubator.py
index 4a4d6d5fe64..11b38660eca 100644
--- a/pylabrobot/storage/incubator.py
+++ b/pylabrobot/storage/incubator.py
@@ -1,6 +1,7 @@
 import random
 from typing import List, Literal, Optional, Union, cast

+from pylabrobot.concurrency import AsyncExitStackWithShielding
 from pylabrobot.machines import Machine
 from pylabrobot.resources import (
   Coordinate,
@@ -59,8 +60,8 @@ def __init__(
   def racks(self) -> List[PlateCarrier]:
     return self._racks

-  async def setup(self, **backend_kwargs):
-    await super().setup()
+  async def _enter_lifespan(self, stack: AsyncExitStackWithShielding) -> None:
+    await super()._enter_lifespan(stack)
     await self.backend.set_racks(self._racks)

   def get_num_free_sites(self) -> int:
diff --git a/pylabrobot/storage/inheco/incubator_shaker.py b/pylabrobot/storage/inheco/incubator_shaker.py
index 4b6c0f0d389..934c5c74754 100644
--- a/pylabrobot/storage/inheco/incubator_shaker.py
+++ b/pylabrobot/storage/inheco/incubator_shaker.py
@@ -1,5 +1,6 @@
 from typing import Dict

+from pylabrobot.concurrency import AsyncExitStackWithShielding
 from pylabrobot.machines.machine import Machine
 from pylabrobot.resources import Coordinate, Resource, ResourceHolder

@@ -72,10 +73,8 @@ def num_units(self) -> int:
     "incubator_shaker_dwp": 2.5,
   }

-  async def setup(self, **backend_kwargs) -> None:
-    """Connect to the stack and build per-unit proxies."""
-
-    await self.backend.setup(**backend_kwargs)
+  async def _enter_lifespan(self, stack: AsyncExitStackWithShielding) -> None:
+    await super()._enter_lifespan(stack)

     self.power_credit = 0.0

@@ -121,10 +120,6 @@ async def setup(self, **backend_kwargs) -> None:
       f"Too many units: unit composition {self.backend.unit_composition} is exceeding 5 power credit limit. Reduce number of units."
     )

-  async def stop(self):
-    """Gracefully stop backend communication."""
-    await self.backend.stop()
-
   async def request_loading_tray_states(self) -> dict:
     """Request loading tray states for all units."""

diff --git a/pylabrobot/storage/inheco/incubator_shaker_backend.py b/pylabrobot/storage/inheco/incubator_shaker_backend.py
index 2d4a7ace352..9e4b2f25a6e 100644
--- a/pylabrobot/storage/inheco/incubator_shaker_backend.py
+++ b/pylabrobot/storage/inheco/incubator_shaker_backend.py
@@ -14,12 +14,14 @@
 - Protocol-conformant parsing for EEPROM, sensor, and status commands.
 """

-import asyncio
 import logging
 import sys
 from functools import wraps
 from typing import Awaitable, Callable, Dict, List, Literal, Optional, TypeVar, cast

+import anyio
+
+from pylabrobot.concurrency import AsyncExitStackWithShielding
 from pylabrobot.io.serial import Serial
 from pylabrobot.machines.machine import MachineBackend

@@ -187,7 +189,7 @@ def __init__(
       InhecoIncubatorUnitType
     ] = []  # e.g. ["incubator_mp", "incubator_shaker_dwp", ...]

-    self._send_command_lock = asyncio.Lock()
+    self._send_command_lock: Optional[anyio.Lock] = None

   @property
   def number_of_connected_units(self) -> int:
@@ -200,14 +202,12 @@ def __repr__(self):
       + f"DIP={self.dip_switch_id}) at {self.io.port}>"
     )

-  async def setup(self, port: Optional[str] = None):
-    """
-    Detect and connect to the Inheco machine stack.
-    Discover Inheco device via VID:PID (0403:6001) and verify DIP switch ID.
-    """
-
-    # --- Establish serial connection ---
-    await self.io.setup()
+  async def _enter_lifespan(
+    self, stack: AsyncExitStackWithShielding, *, port: Optional[str] = None
+  ):
+    await super()._enter_lifespan(stack)
+    self._send_command_lock = anyio.Lock()
+    await stack.enter_async_context(self.io)

     self.io.dtr = False
     self.io.rts = False
@@ -227,17 +227,6 @@ async def setup(self, port: Optional[str] = None):
         f"{self.dip_switch_id}). Please verify the DIP switch setting or wiring."
       )
       self.logger.error(msg, exc_info=e)
-
-      # --- Fail-safe teardown ---
-      try:
-        await self.io.stop()
-        self.logger.debug("Closed serial connection on %s", self.io.port)
-      except Exception as close_err:
-        self.logger.warning(
-          "Failed to close serial port cleanly on %s: %s",
-          self.io._port,
-          close_err,
-        )
       raise RuntimeError(msg) from e

     else:
@@ -271,32 +260,39 @@ async def setup(self, port: Optional[str] = None):
       self.unit_composition,
     )

-  async def stop(self):
-    """Close serial connection & stop all active units in the stack."""
+    async def cleanup():
+      for unit_index in range(self.number_of_connected_units):
+        try:
+          temp_status = await self.is_temperature_control_enabled(stack_index=unit_index)

-    for unit_index in range(self.number_of_connected_units):
-      temp_status = await self.is_temperature_control_enabled(stack_index=unit_index)
+          if temp_status:
+            print(f"Stopping temperature control on unit {unit_index}...")
+            await self.stop_temperature_control(stack_index=unit_index)
+        except Exception as e:
+          self.logger.warning(f"Failed to stop temperature control on unit {unit_index}: {e}")

-      if temp_status:
-        print(f"Stopping temperature control on unit {unit_index}...")
-        await self.stop_temperature_control(stack_index=unit_index)
+        try:
+          shake_status = await self.is_shaking_enabled(stack_index=unit_index)

-      shake_status = await self.is_shaking_enabled(stack_index=unit_index)
+          if shake_status:
+            print(f"Stopping shaking on unit {unit_index}...")
+            await self.stop_shaking(stack_index=unit_index)
+        except Exception as e:
+          self.logger.warning(f"Failed to stop shaking on unit {unit_index}: {e}")

-      if shake_status:
-        print(f"Stopping shaking on unit {unit_index}...")
-        await self.stop_shaking(stack_index=unit_index)
+        try:
+          await self.close(stack_index=unit_index)
+        except Exception as e:
+          self.logger.warning(f"Failed to close unit {unit_index}: {e}")

-      await self.close(stack_index=unit_index)
+    stack.push_shielded_async_callback(cleanup)

-    await self.io.stop()
+    # stop method removed, logic moved to cleanup via AsyncExitStack

   # === Low-level I/O ===

   async def _read_full_response(self, timeout: float) -> bytes:
     """Read a complete Inheco response frame asynchronously."""
-    loop = asyncio.get_event_loop()
-    start = loop.time()
     buf = bytearray()
     expected_hdr = (0xB0 + self.dip_switch_id) & 0xFF

@@ -304,18 +300,16 @@ def has_complete_tail(b: bytearray) -> bool:
       # Valid frame ends with: [hdr][0x20-0x2F][0x60]
       return len(b) >= 3 and b[-1] == 0x60 and b[-3] == expected_hdr and 0x20 <= b[-2] <= 0x2F

-    while True:
-      chunk = await self.io.read(16)
-      if len(chunk) > 0:
-        buf.extend(chunk)
-        if has_complete_tail(buf):
-          self.logger.debug("RECV response: %s", buf.hex(" "))
-          return bytes(buf)
-
-      if loop.time() - start > timeout:
-        raise TimeoutError(f"Timed out waiting for complete response (so far: {buf.hex(' ')})")
+    with anyio.fail_after(timeout):
+      while True:
+        chunk = await self.io.read(16)
+        if len(chunk) > 0:
+          buf.extend(chunk)
+          if has_complete_tail(buf):
+            self.logger.debug("RECV response: %s", buf.hex(" "))
+            return bytes(buf)

-      await asyncio.sleep(0.005)
+        await anyio.sleep(0.005)

   # === Encoding / Decoding ===

@@ -431,14 +425,16 @@ async def send_command(
   ) -> str:
     """Send a framed command and return parsed response or raise InhecoError."""

+    assert self._send_command_lock is not None, "Lock not initialized. Enter context first."
     async with self._send_command_lock:
       # Use global default if not overridden
       w_timeout = write_timeout or self.write_timeout

       msg = self._build_message(command, stack_index=stack_index)
       self.logger.debug("SEND command: %s (write_timeout=%s)", msg.hex(" "), w_timeout)

-      await asyncio.wait_for(self.io.write(msg), timeout=w_timeout)
-      await asyncio.sleep(delay)
+      with anyio.fail_after(w_timeout):
+        await self.io.write(msg)
+      await anyio.sleep(delay)

       response = await self._read_full_response(timeout=read_timeout or self.read_timeout)
       if not response:
@@ -879,7 +875,7 @@ async def wait_for_temperature(
         f"Temperature control is not enabled on the machine ({stack_index}: {self.unit_composition[stack_index]})."
       )

-    start_time = asyncio.get_event_loop().time()
+    start_time = anyio.current_time()
     first_temp = await self.get_temperature(sensor=sensor, stack_index=stack_index)
     initial_diff = abs(first_temp - target_temp)
     bar_width = 40
@@ -900,7 +896,7 @@ async def wait_for_temperature(

       # Compute slope (°C/sec) based on direction of travel
       delta_done = abs(current_temp - first_temp)
-      elapsed = asyncio.get_event_loop().time() - start_time
+      elapsed = anyio.current_time() - start_time
       slope = delta_done / max(elapsed, 1e-6)  # °C per second

@@ -923,7 +919,7 @@ async def wait_for_temperature(
         return current_temp

       if timeout_s is not None:
-        elapsed = asyncio.get_event_loop().time() - start_time
+        elapsed = anyio.current_time() - start_time
         if elapsed > timeout_s:
           if show_progress_bar:
             sys.stdout.write("\n[ERROR] Timeout waiting for temperature.\n")
@@ -935,7 +931,7 @@ async def wait_for_temperature(
             f"did not reach target {target_temp:.2f} °C ±{tolerance:.2f} °C."
           )

-      await asyncio.sleep(interval_s)
+      await anyio.sleep(interval_s)

   # # # Shaking Features # # #

@@ -1256,7 +1252,7 @@ async def shake(
     is_shaking = await self.is_shaking_enabled(stack_index=stack_index)
     if is_shaking:
       await self.stop_shaking(stack_index=stack_index)
-      await asyncio.sleep(0.5)  # brief pause for firmware to settle
+      await anyio.sleep(0.5)  # brief pause for firmware to settle

     await self.set_shaker_pattern(
       pattern=pattern,
diff --git a/pylabrobot/storage/inheco/scila/inheco_sila_interface.py b/pylabrobot/storage/inheco/scila/inheco_sila_interface.py
index 86388ae337c..494ced92d5a 100644
--- a/pylabrobot/storage/inheco/scila/inheco_sila_interface.py
+++ b/pylabrobot/storage/inheco/scila/inheco_sila_interface.py
@@ -1,6 +1,5 @@
 from __future__ import annotations

-import asyncio
 import datetime
 import http.server
 import logging
@@ -13,6 +12,9 @@
 from dataclasses import dataclass
 from typing import Any, Optional, Tuple

+import anyio
+
+from pylabrobot.concurrency import AsyncExitStackWithShielding, AsyncResource
 from pylabrobot.storage.inheco.scila.soap import (
   XSI,
   _localname,
@@ -88,7 +90,7 @@ def __init__(self, code: int, message: str, command: str, details: Optional[dict
     super().__init__(f"Command {command} failed with code {code}: '{message}'")


-class InhecoSiLAInterface:
+class InhecoSiLAInterface(AsyncResource):
   @dataclass(frozen=True)
   class _HTTPRequest:
     method: str
@@ -97,11 +99,17 @@ class _HTTPRequest:
     headers: dict[str, str]
     body: bytes

+  @dataclass
+  class _CommandState:
+    result: Any = None
+    error: Optional[Exception] = None
+
   @dataclass(frozen=True)
   class _SiLACommand:
     name: str
     request_id: int
-    fut: asyncio.Future[Any]
+    event: anyio.Event
+    state: InhecoSiLAInterface._CommandState

   def __init__(
     self,
@@ -114,16 +122,13 @@ def __init__(
     self._logger = logger or logging.getLogger(__name__)

     # single "in-flight token"
-    self._making_request = asyncio.Lock()
+    self._making_request = anyio.Lock()

     # pending command information
     self._pending: Optional[InhecoSiLAInterface._SiLACommand] = None

     # server plumbing
-    self._loop: Optional[asyncio.AbstractEventLoop] = None
     self._httpd: Optional[socketserver.TCPServer] = None
-    self._server_task: Optional[asyncio.Task[None]] = None
-    self._closed = False

   @property
   def client_ip(self) -> str:
@@ -139,13 +144,7 @@ def bound_port(self) -> int:
       raise RuntimeError("Server not started yet")
     return self._httpd.server_address[1]

-  async def start(self) -> None:
-    if self._httpd is not None:
-      return
-    if self._closed:
-      raise RuntimeError("Bridge is closed")
-
-    self._loop = asyncio.get_running_loop()
+  async def _enter_lifespan(self, stack: AsyncExitStackWithShielding) -> None:
     outer = self

     class _Handler(http.server.BaseHTTPRequestHandler):
@@ -159,8 +158,6 @@ def _read_body(self) -> bytes:
         return self.rfile.read(length) if length else b""

       def _do(self) -> None:
-        assert outer._loop is not None
-
         parsed = urllib.parse.urlsplit(self.path)
         req = InhecoSiLAInterface._HTTPRequest(
           method=self.command,
@@ -170,9 +167,8 @@ def _do(self) -> None:
           body=self._read_body(),
         )

-        fut = asyncio.run_coroutine_threadsafe(outer._on_http(req), outer._loop)
         try:
-          resp_body = fut.result()
+          resp_body = anyio.from_thread.run(outer._on_http, req)
           status = 200
         except Exception as e:
           resp_body = f"Internal Server Error: {type(e).__name__}: {e}\n".encode()
@@ -201,23 +197,22 @@ def do_DELETE(self) -> None:

     async def run_server() -> None:
       assert self._httpd is not None
-      await asyncio.to_thread(self._httpd.serve_forever)

-    self._server_task = asyncio.create_task(run_server(), name="http-server")
+      def _serve():
+        assert self._httpd is not None
+        with self._httpd:
+          self._httpd.serve_forever()

-  async def close(self) -> None:
-    self._closed = True
-    if self._httpd is None:
-      return
+      await anyio.to_thread.run_sync(_serve)

-    self._httpd.shutdown()
-    self._httpd.server_close()
-
-    if self._server_task is not None:
-      await self._server_task
+    async def cleanup():
+      assert self._httpd is not None
+      await anyio.to_thread.run_sync(self._httpd.shutdown)
+      self._httpd = None

-    self._httpd = None
-    self._server_task = None
+    tg = await stack.enter_async_context(anyio.create_task_group())
+    stack.push_shielded_async_callback(cleanup)
+    tg.start_soon(run_server)

   async def _on_http(self, req: _HTTPRequest) -> bytes:
     """
@@ -232,21 +227,22 @@ async def _on_http(self, req: _HTTPRequest) -> bytes:
     payload = soap_body_payload(xml_str)
     tag_local = _localname(payload.tag)

-    if cmd is not None and not cmd.fut.done() and tag_local == "ResponseEvent":
+    if cmd is not None and not cmd.event.is_set() and tag_local == "ResponseEvent":
       response_event = soap_decode(xml_str)
       if response_event["ResponseEvent"].get("requestId") == cmd.request_id:
         ret = response_event["ResponseEvent"].get("returnValue", {})
         rc = ret.get("returnCode")
         if rc != 3:  # 3=Success
-          cmd.fut.set_exception(
-            SiLAError(rc, ret.get("message", "").replace(chr(10), " "), cmd.name, details=ret)
+          cmd.state.error = SiLAError(
+            rc, ret.get("message", "").replace(chr(10), " "), cmd.name, details=ret
           )
         else:
-          cmd.fut.set_result(
+          cmd.state.result = (
             ET.fromstring(d)
             if (d := response_event["ResponseEvent"].get("responseData"))
             else ET.Element("EmptyResponse")
           )
+        cmd.event.set()

     if tag_local == "DataEvent":
       try:
@@ -280,9 +276,6 @@ def _get_return_code_and_message(self, command_name: str, response: Any) -> Tupl
       raise ValueError(f"returnCode not found in response for {command_name}")
     return return_code, result_level.get("message", "")

-  async def setup(self) -> None:
-    await self.start()
-
   def _make_request_id(self):
     return random.randint(1, 2**31 - 1)

@@ -291,8 +284,8 @@ async def send_command(
     command: str,
     **kwargs,
   ) -> Any:
-    if self._closed:
-      raise RuntimeError("Bridge is closed")
+    if self._httpd is None:
+      raise RuntimeError("Server not started")

     request_id = self._make_request_id()
     cmd_xml = soap_encode(
@@ -327,18 +320,22 @@ def _do_request() -> bytes:
       with urllib.request.urlopen(req) as resp:
         return resp.read()  # type: ignore

-      body = await asyncio.to_thread(_do_request)
+      body = await anyio.to_thread.run_sync(_do_request)
       return_code, message = self._get_return_code_and_message(
         command, soap_decode(body.decode("utf-8"))
       )
       if return_code == 1:  # success
         return soap_decode(body.decode("utf-8"))
       elif return_code == 2:  # concurrent command
-        fut: asyncio.Future[Any] = asyncio.get_running_loop().create_future()
+        event = anyio.Event()
+        state = InhecoSiLAInterface._CommandState()
         self._pending = InhecoSiLAInterface._SiLACommand(
-          name=command, request_id=request_id, fut=fut
+          name=command, request_id=request_id, event=event, state=state
         )
-        return await fut  # wait for response to be handled in _on_http
+        await event.wait()
+        if self._pending.state.error is not None:
+          raise self._pending.state.error
+        return self._pending.state.result
       else:
         raise RuntimeError(f"command {command} failed: {return_code} {message}")
     finally:
diff --git a/pylabrobot/storage/inheco/scila/scila_backend.py b/pylabrobot/storage/inheco/scila/scila_backend.py
index 53b94c07c02..6a15790cfda 100644
--- a/pylabrobot/storage/inheco/scila/scila_backend.py
+++ b/pylabrobot/storage/inheco/scila/scila_backend.py
@@ -1,6 +1,7 @@
 import xml.etree.ElementTree as ET
 from typing import Any, Dict, Literal, Optional

+from pylabrobot.concurrency import AsyncExitStackWithShielding
 from pylabrobot.machines.backend import MachineBackend
 from pylabrobot.storage.inheco.scila.inheco_sila_interface import InhecoSiLAInterface

@@ -38,13 +39,11 @@ class SCILABackend(MachineBackend):
   def __init__(self, scila_ip: str, client_ip: Optional[str] = None) -> None:
     self._sila_interface = InhecoSiLAInterface(client_ip=client_ip, machine_ip=scila_ip)

-  async def setup(self) -> None:
-    await self._sila_interface.setup()
+  async def _enter_lifespan(self, stack: AsyncExitStackWithShielding) -> None:
+    await super()._enter_lifespan(stack)
+    await stack.enter_async_context(self._sila_interface)
     await self._reset_and_initialize()

-  async def stop(self) -> None:
-    await self._sila_interface.close()
-
   async def _reset_and_initialize(self) -> None:
     event_uri = f"http://{self._sila_interface.client_ip}:{self._sila_interface.bound_port}/"
     await self._sila_interface.send_command(
diff --git a/pylabrobot/storage/inheco/scila/scila_backend_tests.py b/pylabrobot/storage/inheco/scila/scila_backend_tests.py
index 63120a0384d..de696058950 100644
--- a/pylabrobot/storage/inheco/scila/scila_backend_tests.py
+++ b/pylabrobot/storage/inheco/scila/scila_backend_tests.py
@@ -20,19 +20,20 @@ def tearDown(self):
     self.patcher.stop()

   async def test_setup(self):
-    await self.backend.setup()
-    self.mock_sila_interface.setup.assert_called_once()
-    self.mock_sila_interface.send_command.assert_any_call(
-      command="Reset",
-      deviceId="MyController",
-      eventReceiverURI="http://127.0.0.1:80/",
-      simulationMode=False,
-    )
-    self.mock_sila_interface.send_command.assert_any_call("Initialize")
+    async with self.backend:
+      self.mock_sila_interface.__aenter__.assert_called_once()
+      self.mock_sila_interface.send_command.assert_any_call(
+        command="Reset",
+        deviceId="MyController",
+        eventReceiverURI="http://127.0.0.1:80/",
+        simulationMode=False,
+      )
+      self.mock_sila_interface.send_command.assert_any_call("Initialize")

   async def test_stop(self):
-    await self.backend.stop()
-    self.mock_sila_interface.close.assert_called_once()
+    async with self.backend:
+      pass
+    self.mock_sila_interface.__aexit__.assert_called_once()

   async def test_request_status(self):
     self.mock_sila_interface.send_command.return_value = {"GetStatusResponse": {"state": "standBy"}}
@@ -234,7 +235,3 @@ def test_deserialize_no_client_ip(self):
     data = {"scila_ip": "169.254.1.117"}
     SCILABackend.deserialize(data)
     self.MockInhecoSiLAInterface.assert_called_with(client_ip=None, machine_ip="169.254.1.117")
-
-
-if __name__ == "__main__":
-  unittest.main()
diff --git a/pylabrobot/storage/liconic/liconic_backend.py b/pylabrobot/storage/liconic/liconic_backend.py
index f9b770f727f..913fbcc24dd 100644
--- a/pylabrobot/storage/liconic/liconic_backend.py
+++ b/pylabrobot/storage/liconic/liconic_backend.py
@@ -1,10 +1,10 @@
-import asyncio
 import logging
 import re
-import time
 import warnings
 from typing import List, Optional, Tuple, Union

+import anyio
+
 try:
   import serial

@@ -14,6 +14,7 @@
   _SERIAL_IMPORT_ERROR = e

 from pylabrobot.barcode_scanners import BarcodeScanner
+from pylabrobot.concurrency import AsyncExitStackWithShielding
 from pylabrobot.io.serial import Serial
 from pylabrobot.resources import Plate, PlateHolder
 from pylabrobot.resources.barcode import Barcode
@@ -93,7 +94,7 @@ def __init__(
     self.n2_installed: Optional[bool] = None

   # Function to setup serial connection with Liconic PLC
-  async def setup(self):
+  async def _enter_lifespan(self, stack: AsyncExitStackWithShielding):
     """
     1. Open serial port (9600 8E1, RTS/CTS) via the Serial wrapper.
     2. Send >200 ms break, wait 150 ms, flush buffers.
@@ -101,42 +102,44 @@ async def setup(self):
     4. Activate handling: ST 1801 → expect OK
     5. Poll ready-flag: RD 1915 → wait for "1"
     """
+    await super()._enter_lifespan(stack)
    try:
-      await self.io.setup()
+      await stack.enter_async_context(self.io)
     except serial.SerialException as e:
       raise RuntimeError(f"Could not open {self.io.port}: {e}") from e

     await self.io.send_break(duration=0.2)  # >100 ms required
-    await asyncio.sleep(0.15)
+    await anyio.sleep(0.15)
     await self.io.reset_input_buffer()
     await self.io.reset_output_buffer()

     await self.io.write(b"CR\r")
-    deadline = time.time() + self.init_timeout
-    while time.time() < deadline:
-      resp = await self.io.readline()  # reads through LF
-      if resp.strip() == b"CC":
-        break
-    else:
-      await self.io.stop()
-      raise TimeoutError(f"No CC response from Liconic PLC within {self.init_timeout} seconds")
+    try:
+      with anyio.fail_after(self.init_timeout):
+        while True:
+          resp = await self.io.readline()  # reads through LF
+          if resp.strip() == b"CC":
+            break
+    except TimeoutError as e:
+      raise TimeoutError(
+        f"No CC response from Liconic PLC within {self.init_timeout} seconds"
+      ) from e

     await self.io.write(b"ST 1801\r")
     resp = await self.io.readline()
     if resp.strip() != b"OK":
-      await self.io.stop()
       raise RuntimeError(f"Unexpected reply to ST 1801: {resp!r}")

-    deadline = time.time() + self.start_timeout
-    while time.time() < deadline:
-      await self.io.write(b"RD 1915\r")
-      flag = await self.io.readline()
-      if flag.strip() == b"1":
-        break
-      await asyncio.sleep(self.poll_interval)
-    else:
-      await self.io.stop()
-      raise TimeoutError(f"PLC did not signal ready within {self.start_timeout} seconds")
+    try:
+      with anyio.fail_after(self.start_timeout):
+        while True:
+          await self.io.write(b"RD 1915\r")
+          flag = await self.io.readline()
+          if flag.strip() == b"1":
+            break
+          await anyio.sleep(self.poll_interval)
+    except TimeoutError as e:
+      raise TimeoutError(f"PLC did not signal ready within {self.start_timeout} seconds") from e

   def _site_to_m_n(self, site: PlateHolder) -> Tuple[int, int]:
     rack = site.parent
@@ -166,9 +169,6 @@ def
_carrier_to_steps_pos(self, site: PlateHolder) -> Tuple[int, int]: f"Could not parse site height and pos num from PlateCarrier model: {rack.model}" ) - async def stop(self): - await self.io.stop() - async def set_racks(self, racks: List[PlateCarrier]): await super().set_racks(racks) warnings.warn("Liconic racks need to be configured manually on each setup") @@ -307,36 +307,38 @@ async def _wait_plate_ready(self, timeout: int = 60): """ Poll the plate-ready flag (RD 1914) until it is set, or timeout is reached. """ - start = time.time() - deadline = start + timeout - while time.time() < deadline: - resp = await self._send_command("RD 1914") - if resp == "1": - return - await asyncio.sleep(0.1) - raise TimeoutError(f"Plate did not become ready within {timeout} seconds") + try: + with anyio.fail_after(timeout): + while True: + resp = await self._send_command("RD 1914") + if resp == "1": + return + await anyio.sleep(0.1) + except TimeoutError: + raise TimeoutError(f"Plate was not ready within {timeout} seconds") from None async def _wait_ready(self, timeout: int = 60): """ Poll the ready-flag (RD 1915) until it is set. If timeout is reached the error flag is read and if true aka "1" then the error register is read. 
""" - start = time.time() - deadline = start + timeout - while time.time() < deadline: - resp = await self._send_command("RD 1915") - if resp == "1": - return - await asyncio.sleep(0.1) - err_flag = await self._send_command("RD 1814") - if err_flag == "1": - error = await self._send_command("RD DM200") - for member in HandlingError: - if error == member.value: - cls, msg = handler_error_map[member] - raise cls(msg) - raise RuntimeError(f"Liconic Handler in unknown error state with memory showing {error}") - raise TimeoutError(f"Incubator did not become ready within {timeout} seconds") + try: + with anyio.fail_after(timeout): + while True: + resp = await self._send_command("RD 1915") + if resp == "1": + return + await anyio.sleep(0.1) + except TimeoutError: + err_flag = await self._send_command("RD 1814") + if err_flag == "1": + error = await self._send_command("RD DM200") + for member in HandlingError: + if error == member.value: + cls, msg = handler_error_map[member] + raise cls(msg) + raise RuntimeError(f"Liconic Handler in unknown error state with memory showing {error}") + raise TimeoutError(f"Incubator did not become ready within {timeout} seconds") async def set_temperature(self, temperature: float): """Set the temperature of the incubator in degrees Celsius. 
Using command WR DM890 ttttt @@ -527,7 +529,7 @@ async def check_shovel_sensor(self) -> bool: UNTESTED.""" await self._send_command("ST 1911") - await asyncio.sleep(0.1) + await anyio.sleep(0.1) resp = await self._send_command("RD 1812") if resp == "1": return True diff --git a/pylabrobot/storage/liconic/liconic_backend_tests.py b/pylabrobot/storage/liconic/liconic_backend_tests.py index b3ed09dccfa..e9b1e5a3769 100644 --- a/pylabrobot/storage/liconic/liconic_backend_tests.py +++ b/pylabrobot/storage/liconic/liconic_backend_tests.py @@ -409,7 +409,3 @@ async def test_send_command_raises_on_unknown_error(self): with self.assertRaises(RuntimeError) as ctx: await self.backend._send_command("ST 1801") self.assertIn("Unknown error", str(ctx.exception)) - - -if __name__ == "__main__": - unittest.main() diff --git a/pylabrobot/temperature_controlling/chatterbox.py b/pylabrobot/temperature_controlling/chatterbox.py index b53510080ac..5cfd445bb21 100644 --- a/pylabrobot/temperature_controlling/chatterbox.py +++ b/pylabrobot/temperature_controlling/chatterbox.py @@ -1,3 +1,4 @@ +from pylabrobot.concurrency import AsyncExitStackWithShielding from pylabrobot.temperature_controlling.backend import ( TemperatureControllerBackend, ) @@ -13,11 +14,10 @@ def supports_active_cooling(self) -> bool: def __init__(self, dummy_temperature: float = 0.0) -> None: self._dummy_temperature = dummy_temperature - async def setup(self): + async def _enter_lifespan(self, stack: AsyncExitStackWithShielding): + await super()._enter_lifespan(stack) print("Setting up the temperature controller.") - - async def stop(self): - print("Stopping the temperature controller.") + stack.callback(lambda: print("Stopping the temperature controller.")) async def set_temperature(self, temperature: float): print(f"Setting the temperature to {temperature}.") diff --git a/pylabrobot/temperature_controlling/inheco/control_box.py b/pylabrobot/temperature_controlling/inheco/control_box.py index d4699232797..7a65931ec42 
100644 --- a/pylabrobot/temperature_controlling/inheco/control_box.py +++ b/pylabrobot/temperature_controlling/inheco/control_box.py @@ -1,6 +1,8 @@ -import time +import contextlib import typing +import anyio + from pylabrobot.io.hid import HID @@ -15,14 +17,11 @@ def __init__( human_readable_device_name="Inheco Control Box", vid=vid, pid=pid, serial_number=serial_number ) - async def setup(self): + async def _enter_lifespan(self, stack: contextlib.AsyncExitStack): """ - If io.setup() fails, ensure that libusb drivers were installed as per docs. + If HID._enter_lifespan() fails, ensure that libusb drivers were installed as per docs. """ - await self.io.setup() - - async def stop(self): - await self.io.stop() + await stack.enter_async_context(self.io) @typing.no_type_check def _generate_packets(self, msg): @@ -83,21 +82,21 @@ def _crc8(self, data, crc: int) -> int: async def _read_until_end(self, timeout: int) -> str: """Read until a packet ends with a \\x00 byte. May read multiple packets.""" - start = time.time() response = b"" - while time.time() - start < timeout: - packet = await self.io.read(64, timeout=timeout) - if packet is not None and packet != b"": - if packet.endswith(b"\x00"): - response += packet.rstrip(b"\x00") # strip trailing \x00's - break - elif packet.endswith(b"#"): - response += packet[:-1] - continue - else: - # I have never seen this happen, commands always end with \x00 or '#' - print("weird packet, please report", packet) - response += packet + with anyio.fail_after(timeout): + while True: + packet = await self.io.read(64, timeout=timeout) + if packet is not None and packet != b"": + if packet.endswith(b"\x00"): + response += packet.rstrip(b"\x00") # strip trailing \x00's + break + elif packet.endswith(b"#"): + response += packet[:-1] + continue + else: + # I have never seen this happen, commands always end with \x00 or '#' + print("weird packet, please report", packet) + response += packet return response.decode("unicode_escape") @@ 
-109,15 +108,12 @@ async def _read_response(self, command: str, timeout: int = 60) -> str: is 5ase0. Therefore it is easy to identify correct answers to the commands. This feature may increase integrity of the communication." """ + with anyio.fail_after(timeout): + while True: + response = await self._read_until_end(timeout=timeout) - start = time.time() - while time.time() - start < timeout: - response = await self._read_until_end(timeout=int(timeout - (time.time() - start))) - - if response[:4] == command[:4].lower(): - return response - - raise TimeoutError("Timeout while waiting for response from device.") + if response[:4] == command[:4].lower(): + return response async def send_command(self, command: str, timeout: int = 3): """Send a command to the device and return the response""" diff --git a/pylabrobot/temperature_controlling/inheco/temperature_controller.py b/pylabrobot/temperature_controlling/inheco/temperature_controller.py index 6c78abb5a40..e188a3f1c12 100644 --- a/pylabrobot/temperature_controlling/inheco/temperature_controller.py +++ b/pylabrobot/temperature_controlling/inheco/temperature_controller.py @@ -1,6 +1,7 @@ import abc import warnings +from pylabrobot.concurrency import AsyncExitStackWithShielding from pylabrobot.temperature_controlling.backend import TemperatureControllerBackend from pylabrobot.temperature_controlling.inheco.control_box import InhecoTECControlBox @@ -17,11 +18,9 @@ def __init__(self, index: int, control_box: InhecoTECControlBox): self.index = index self.interface = control_box - async def setup(self): - pass - - async def stop(self): - await self.stop_temperature_control() + async def _enter_lifespan(self, stack: AsyncExitStackWithShielding): + await super()._enter_lifespan(stack) + stack.push_shielded_async_callback(self.stop_temperature_control) def serialize(self) -> dict: warnings.warn("The interface is not serialized.") diff --git a/pylabrobot/temperature_controlling/opentrons_backend.py 
b/pylabrobot/temperature_controlling/opentrons_backend.py index 4072ea4ae0e..92d14f30431 100644 --- a/pylabrobot/temperature_controlling/opentrons_backend.py +++ b/pylabrobot/temperature_controlling/opentrons_backend.py @@ -1,5 +1,6 @@ from typing import cast +from pylabrobot.concurrency import AsyncExitStackWithShielding from pylabrobot.temperature_controlling.backend import ( TemperatureControllerBackend, ) @@ -35,11 +36,9 @@ def __init__(self, opentrons_id: str): f" Import error: {_OT_IMPORT_ERROR}." ) - async def setup(self): - pass - - async def stop(self): - await self.deactivate() + async def _enter_lifespan(self, stack: AsyncExitStackWithShielding): + await super()._enter_lifespan(stack) + stack.push_shielded_async_callback(self.deactivate) def serialize(self) -> dict: return {**super().serialize(), "opentrons_id": self.opentrons_id} diff --git a/pylabrobot/temperature_controlling/opentrons_backend_usb.py b/pylabrobot/temperature_controlling/opentrons_backend_usb.py index dd941633d55..455d4f3dda6 100644 --- a/pylabrobot/temperature_controlling/opentrons_backend_usb.py +++ b/pylabrobot/temperature_controlling/opentrons_backend_usb.py @@ -1,5 +1,6 @@ from typing import Optional +from pylabrobot.concurrency import AsyncExitStackWithShielding from pylabrobot.io.serial import Serial from pylabrobot.temperature_controlling.backend import ( TemperatureControllerBackend, @@ -29,21 +30,16 @@ def serial(self) -> "Serial": raise RuntimeError("Serial device not initialized. 
Call setup() first.") return self._serial - async def setup(self): - # Setup serial communication for USB + async def _enter_lifespan(self, stack: AsyncExitStackWithShielding): + await super()._enter_lifespan(stack) self._serial = Serial( human_readable_device_name="Opentrons Temperature Module", port=self.port, baudrate=115200, timeout=3, ) - await self._serial.setup() - - async def stop(self): - await self.deactivate() - if self._serial is not None: - await self._serial.stop() - self._serial = None + await stack.enter_async_context(self._serial) + stack.push_shielded_async_callback(self.deactivate) def serialize(self) -> dict: return {**super().serialize(), "port": self.port} diff --git a/pylabrobot/temperature_controlling/temperature_controller.py b/pylabrobot/temperature_controlling/temperature_controller.py index 3b58305c7dc..98965d772f7 100644 --- a/pylabrobot/temperature_controlling/temperature_controller.py +++ b/pylabrobot/temperature_controlling/temperature_controller.py @@ -1,7 +1,8 @@ -import asyncio -import time from typing import Optional +import anyio + +from pylabrobot.concurrency import AsyncExitStackWithShielding from pylabrobot.machines.machine import Machine from pylabrobot.resources import Coordinate, ResourceHolder @@ -78,13 +79,15 @@ async def wait_for_temperature(self, timeout: float = 300.0, tolerance: float = """ if self.target_temperature is None: raise RuntimeError("Target temperature is not set.") - start = time.time() - while time.time() - start < timeout: - temperature = await self.get_temperature() - if abs(temperature - self.target_temperature) < tolerance: - return - await asyncio.sleep(1.0) - raise TimeoutError(f"Temperature did not reach target temperature within {timeout} seconds.") + try: + with anyio.fail_after(timeout): + while True: + temperature = await self.get_temperature() + if abs(temperature - self.target_temperature) < tolerance: + return + await anyio.sleep(1.0) + except TimeoutError: + raise 
TimeoutError(f"Temperature did not reach target within {timeout} seconds") from None async def deactivate(self): """Deactivate the temperature controller. This will stop the heating or cooling, and return @@ -93,10 +96,13 @@ async def deactivate(self): self.target_temperature = None return await self.backend.deactivate() - async def stop(self): - """Stop the temperature controller and close the backend connection.""" - await self.deactivate() - await super().stop() + async def _enter_lifespan(self, stack: AsyncExitStackWithShielding) -> None: + await super()._enter_lifespan(stack) + + async def cleanup(): + await self.deactivate() + + stack.push_shielded_async_callback(cleanup) def serialize(self) -> dict: return { diff --git a/pylabrobot/temperature_controlling/temperature_controller_tests.py b/pylabrobot/temperature_controlling/temperature_controller_tests.py index 5e2a5c8cb86..2f025c14d8d 100644 --- a/pylabrobot/temperature_controlling/temperature_controller_tests.py +++ b/pylabrobot/temperature_controlling/temperature_controller_tests.py @@ -1,14 +1,13 @@ -import unittest - from pylabrobot.resources.coordinate import Coordinate from pylabrobot.temperature_controlling import ( TemperatureController, TemperatureControllerChatterboxBackend, ) from pylabrobot.temperature_controlling.backend import TemperatureControllerBackend +from pylabrobot.testing.concurrency import AnyioTestBase -class TemperatureControllerTests(unittest.TestCase): +class TestTemperatureController(AnyioTestBase): def test_serialization(self): tc = TemperatureController( name="test_tc", @@ -24,7 +23,7 @@ def test_serialization(self): self.assertEqual(tc, deserialized) -class PassiveCoolingTests(unittest.IsolatedAsyncioTestCase): +class TestPassiveCooling(AnyioTestBase): async def test_cannot_cool_without_support(self): backend = TemperatureControllerChatterboxBackend(dummy_temperature=20.0) tc = TemperatureController( @@ -65,12 +64,6 @@ def __init__(self, temperature: float = 25.0): def 
supports_active_cooling(self) -> bool: return True - async def setup(self): - pass - - async def stop(self): - pass - async def set_temperature(self, temperature: float): self.set_called = True self.temperature = temperature @@ -82,7 +75,7 @@ async def deactivate(self): pass -class PassiveCoolingWithSupportTests(unittest.IsolatedAsyncioTestCase): +class TestPassiveCoolingWithSupport(AnyioTestBase): async def test_passive_cooling_with_support(self): backend = _FakeBackend(temperature=30.0) tc = TemperatureController( diff --git a/pylabrobot/testing/__init__.py b/pylabrobot/testing/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/pylabrobot/testing/concurrency.py b/pylabrobot/testing/concurrency.py new file mode 100644 index 00000000000..d4e7b396da6 --- /dev/null +++ b/pylabrobot/testing/concurrency.py @@ -0,0 +1,182 @@ +import inspect +from contextlib import contextmanager + +import anyio +import pytest + +from pylabrobot.concurrency import _AsyncResourceBase + + +def lifespan_kwargs(**kwargs): + def decorator(func): + func._lifespan_kwargs = kwargs + return func + + return decorator + + +# Note: pytest doesn't like classes with __new__, so we use _AsyncResourceBase instead of AsyncResource +class AnyioTestBase(_AsyncResourceBase): + """A test base class enabling structured concurrency. + + Intended as a replacement for `unittest.IsolatedAsyncioTestCase`. + The `unittest` test paradigm of setUp -> test -> tearDown is + fundamentally incompatible with structured concurrency. + + It is recommended to move away from `unittest` and towards `pytest`, + but this class can be used to ease the transition, + by not requiring the test cases to be re-written. + + To convert a test case from `unittest.IsolatedAsyncioTestCase` to `AnyioTestBase`, + you need to replace all `setUp`/`asyncSetUp`/`asyncTearDown`/`tearDown` methods + with a single `_lifespan` context manager method instead. + Then, the test cases themselves can remain unchanged. 
+ + Example + ```python + from contextlib import asynccontextmanager + + from pylabrobot.testing.concurrency import AnyioTestBase + + class TestMyClass(AnyioTestBase): + @asynccontextmanager + async def _lifespan(self): + self.lh = LiquidHandler(...) + async with self.lh: + yield + + def test_my_test(self): + self.assertIsNotNone(self.lh) + ``` + """ + + def __init_subclass__(cls): + def wrap(wrapped): + @pytest.mark.parametrize("backend", ["asyncio", "trio"]) + def sync_wrapper(self, backend, *args, **kwargs): + lifespan_kwargs = getattr(wrapped, "_lifespan_kwargs", {}) + + async def async_wrapper(): + async with self._lifespan(**lifespan_kwargs): + if inspect.iscoroutinefunction(wrapped): + return await wrapped(self, *args, **kwargs) + else: + return wrapped(self, *args, **kwargs) + + return anyio.run(async_wrapper, backend=backend) + + sync_wrapper.original_func = wrapped + return sync_wrapper + + for name, value in list(vars(cls).items()): + if name in {"setUp", "asyncSetUp", "tearDown", "asyncTearDown"}: + raise TypeError( + f"Class {cls.__name__} should not have {name} method, use _lifespan or _enter_lifespan instead." + ) + if name.startswith("test_"): + setattr(cls, name, wrap(value)) + + async def _enter_lifespan(self, stack): + """Helper for the _lifespan implementation; override this instead of _lifespan. + + Note: child classes may add keyword-only arguments to the signature, as _lifespan + forwards those. 
+ """ + pass + + def assertEqual(self, first, second, msg=None): + assert first == second, msg or f"{first} != {second}" + + def assertNotEqual(self, first, second, msg=None): + assert first != second, msg or f"{first} == {second}" + + def assertIn(self, member, container, msg=None): + assert member in container, msg or f"{member!r} not found in {container!r}" + + def assertNotIn(self, member, container, msg=None): + assert member not in container, msg or f"{member!r} found in {container!r}" + + def assertAlmostEqual(self, first, second, places=7, msg=None, delta=None): + if delta is not None: + assert abs(first - second) <= delta, msg or f"{first} != {second} within {delta}" + else: + assert round(abs(first - second), places) == 0, ( + msg or f"{first} != {second} within {places} places" + ) + + def assertNotAlmostEqual(self, first, second, places=7, msg=None, delta=None): + if delta is not None: + assert abs(first - second) > delta, msg or f"{first} == {second} within {delta}" + else: + assert round(abs(first - second), places) != 0, ( + msg or f"{first} == {second} within {places} places" + ) + + def assertIsInstance(self, obj, cls, msg=None): + assert isinstance(obj, cls), msg or f"{obj!r} is not an instance of {cls.__name__}" + + def assertTrue(self, expr, msg=None): + assert expr, msg or f"{expr!r} is not True" + + def assertFalse(self, expr, msg=None): + assert not expr, msg or f"{expr!r} is not False" + + def assertIsNone(self, obj, msg=None): + assert obj is None, msg or f"{obj!r} is not None" + + def assertGreater(self, a, b, msg=None): + assert a > b, msg or f"{a} not greater than {b}" + + def assertIsNotNone(self, obj, msg=None): + assert obj is not None, msg or f"{obj!r} is None" + + @contextmanager + def subTest(self, msg=None, **kwargs): + try: + yield + except Exception as e: + parts = [] + if msg: + parts.append(f"msg={msg}") + if kwargs: + parts.append(", ".join(f"{k}={v}" for k, v in kwargs.items())) + err_msg = f"subTest failed: {', 
'.join(parts)}" if parts else "subTest failed" + raise AssertionError(f"{err_msg}\nOriginal error: {e}") from e + + @contextmanager + def assertRaises(self, exc_type, exc_value=None, msg=None): + class Context: + def __init__(self): + self.exception = None + + ctx = Context() + try: + yield ctx + except Exception as e: + ctx.exception = e + if not isinstance(e, exc_type): + raise AssertionError(msg or f"Expected exception of type {exc_type.__name__}, got {e!r}") + if exc_value is not None and e != exc_value: + raise AssertionError(msg or f"Expected {exc_value!r}, got {e!r}") + else: + raise AssertionError(msg or "No exception raised") + + @contextmanager + def assertRaisesRegex(self, exc_type, regex, msg=None): + with self.assertRaises(exc_type) as ctx: + yield ctx + if ctx.exception is not None: + import re + + if not re.search(regex, str(ctx.exception)): + raise AssertionError(msg or f"{regex!r} does not match {str(ctx.exception)!r}") + + @contextmanager + def assertWarns(self, expected_warning): + with pytest.warns(expected_warning): + yield + + def fail(self, msg): + pytest.fail(msg) diff --git a/pylabrobot/testing/mock_io.py b/pylabrobot/testing/mock_io.py new file mode 100644 index 00000000000..705a1ce5910 --- /dev/null +++ b/pylabrobot/testing/mock_io.py @@ -0,0 +1,54 @@ +import anyio + +from pylabrobot.io.io import IOBase + + +class CustomReadMock: + def __init__(self): + self.side_effect = None + + async def __call__(self, *args, **kwargs): + await anyio.sleep(0) + if self.side_effect is None: + return b"" + if isinstance(self.side_effect, list): + if not self.side_effect: + raise IndexError("Mock side effect list exhausted") + return self.side_effect.pop(0) + if callable(self.side_effect): + return self.side_effect(*args, **kwargs) + return self.side_effect + + def reset_mock(self): + self.side_effect = None + + +class CustomWriteMock: + def __init__(self): 
self.side_effect = None + + async def __call__(self, data: bytes, *args, **kwargs): + await anyio.sleep(0) + if callable(self.side_effect): + return self.side_effect(data, *args, **kwargs) + + def reset_mock(self): + self.side_effect = None + + +class MockIO(IOBase): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._read = CustomReadMock() + self._write = CustomWriteMock() + + async def _enter_lifespan(self, stack, **kwargs): + pass + + @property + def write(self): + return self._write + + @property + def read(self): + return self._read diff --git a/pylabrobot/thermocycling/chatterbox.py b/pylabrobot/thermocycling/chatterbox.py index 1c45e40752d..df7dcbab917 100644 --- a/pylabrobot/thermocycling/chatterbox.py +++ b/pylabrobot/thermocycling/chatterbox.py @@ -1,6 +1,7 @@ from dataclasses import dataclass from typing import List, Optional +from pylabrobot.concurrency import AsyncExitStackWithShielding from pylabrobot.thermocycling.backend import ThermocyclerBackend from pylabrobot.thermocycling.standard import BlockStatus, LidStatus, Protocol @@ -49,11 +50,10 @@ def __init__(self, name: str = "thermocycler_chatterbox", num_zones: int = 1): self._state = ThermocyclerState(num_zones=num_zones) self.num_zones = num_zones - async def setup(self): + async def _enter_lifespan(self, stack: AsyncExitStackWithShielding): + await super()._enter_lifespan(stack) print("Setting up thermocycler.") - - async def stop(self): - print("Stopping thermocycler.") + stack.callback(lambda: print("Stopping thermocycler.")) async def open_lid(self): print("Opening lid.") diff --git a/pylabrobot/thermocycling/chatterbox_tests.py b/pylabrobot/thermocycling/chatterbox_tests.py index dcd85299aae..e7d9310d9db 100644 --- a/pylabrobot/thermocycling/chatterbox_tests.py +++ b/pylabrobot/thermocycling/chatterbox_tests.py @@ -1,15 +1,14 @@ -import unittest from contextlib import redirect_stdout from io import StringIO from pylabrobot.resources import Coordinate +from 
pylabrobot.testing.concurrency import AnyioTestBase from pylabrobot.thermocycling import Thermocycler, ThermocyclerChatterboxBackend from pylabrobot.thermocycling.standard import Protocol, Stage, Step -class TestThermocyclerChatterbox(unittest.IsolatedAsyncioTestCase): - def __init__(self, methodName="runTest"): - super().__init__(methodName) +class TestThermocyclerChatterbox(AnyioTestBase): + async def _enter_lifespan(self, stack): self.tc = Thermocycler( name="tc_test", size_x=1, diff --git a/pylabrobot/thermocycling/inheco/odtc_backend.py b/pylabrobot/thermocycling/inheco/odtc_backend.py index 84e1b69665b..8982f63d3d3 100644 --- a/pylabrobot/thermocycling/inheco/odtc_backend.py +++ b/pylabrobot/thermocycling/inheco/odtc_backend.py @@ -1,9 +1,11 @@ -import asyncio import datetime import time import xml.etree.ElementTree as ET from typing import Any, Dict, List, Optional +import anyio + +from pylabrobot.concurrency import AsyncExitStackWithShielding from pylabrobot.storage.inheco.scila.inheco_sila_interface import InhecoSiLAInterface, SiLAError from pylabrobot.thermocycling.backend import ThermocyclerBackend from pylabrobot.thermocycling.standard import BlockStatus, LidStatus, Protocol @@ -49,12 +51,11 @@ def __init__(self, ip: str, client_ip: Optional[str] = None) -> None: self._current_sensors: Dict[str, float] = {} self._temp_update_time: float = 0 - async def setup(self) -> None: - await self._sila_interface.setup() - await self._reset_and_initialize() + async def _enter_lifespan(self, stack: AsyncExitStackWithShielding) -> None: + await super()._enter_lifespan(stack) + await stack.enter_async_context(self._sila_interface) - async def stop(self): - await self._sila_interface.close() + await self._reset_and_initialize() async def _reset_and_initialize(self) -> None: try: @@ -68,14 +69,13 @@ async def _reset_and_initialize(self) -> None: async def _wait_for_idle(self, timeout=30): """Wait until device state is not Busy.""" - start = time.time() - while 
time.time() - start < timeout:
-      root = await self._sila_interface.send_command("GetStatus")
-      st = _recursive_find_key(root, "state")
-      if st and st in ["idle", "standby"]:
-        return
-      await asyncio.sleep(1)
-    raise RuntimeError("Timeout waiting for ODTC idle state")
+    with anyio.fail_after(timeout):
+      while True:
+        root = await self._sila_interface.send_command("GetStatus")
+        st = _recursive_find_key(root, "state")
+        if st and st in ["idle", "standby"]:
+          return
+        await anyio.sleep(1)
 
   # -------------------------------------------------------------------------
   # Lid
diff --git a/pylabrobot/thermocycling/opentrons_backend.py b/pylabrobot/thermocycling/opentrons_backend.py
index 039708f348f..1b2a67b1bbf 100644
--- a/pylabrobot/thermocycling/opentrons_backend.py
+++ b/pylabrobot/thermocycling/opentrons_backend.py
@@ -2,6 +2,7 @@
 
 from typing import List, Optional, cast
 
+from pylabrobot.concurrency import AsyncExitStackWithShielding
 from pylabrobot.thermocycling.backend import ThermocyclerBackend
 from pylabrobot.thermocycling.standard import BlockStatus, LidStatus, Protocol
 
@@ -45,13 +46,11 @@ def __init__(self, opentrons_id: str):
     self.opentrons_id = opentrons_id
     self._current_protocol: Optional[Protocol] = None
 
-  async def setup(self):
-    """No extra setup needed for HTTP-API thermocycler."""
-
-  async def stop(self):
-    """Gracefully deactivate both heaters."""
-    await self.deactivate_block()
-    await self.deactivate_lid()
+  async def _enter_lifespan(self, stack: AsyncExitStackWithShielding):
+    """Gracefully deactivate both heaters on exit."""
+    await super()._enter_lifespan(stack)
+    stack.push_shielded_async_callback(self.deactivate_lid)
+    stack.push_shielded_async_callback(self.deactivate_block)
 
   def serialize(self) -> dict:
     """Include the Opentrons module ID in serialized state."""
diff --git a/pylabrobot/thermocycling/opentrons_backend_tests.py b/pylabrobot/thermocycling/opentrons_backend_tests.py
index c68460a991e..a2341850640 100644
--- a/pylabrobot/thermocycling/opentrons_backend_tests.py
+++ b/pylabrobot/thermocycling/opentrons_backend_tests.py
@@ -1,20 +1,27 @@
-import unittest
 from unittest.mock import patch
 
 import pytest
 
-pytest.importorskip("ot_api")
+from pylabrobot.testing.concurrency import AnyioTestBase
+from pylabrobot.machines.backend import MachineBackend
 
+pytest.importorskip("ot_api")
 from pylabrobot.resources.itemized_resource import ItemizedResource
 from pylabrobot.thermocycling.opentrons import OpentronsThermocyclerModuleV1
 from pylabrobot.thermocycling.opentrons_backend import OpentronsThermocyclerBackend
 from pylabrobot.thermocycling.standard import BlockStatus, LidStatus, Protocol, Stage, Step
 
 
-class TestOpentronsThermocyclerBackend(unittest.IsolatedAsyncioTestCase):
-  async def asyncSetUp(self):
-    await super().asyncSetUp()
-    self.thermocycler_backend = OpentronsThermocyclerBackend(opentrons_id="test_id")
+class MockOpentronsThermocyclerBackend(OpentronsThermocyclerBackend):
+  async def _enter_lifespan(self, stack):
+    await MachineBackend._enter_lifespan(self, stack)
+
+
+class TestOpentronsThermocyclerBackend(AnyioTestBase):
+  async def _enter_lifespan(self, stack):
+    await super()._enter_lifespan(stack)
+    self.thermocycler_backend = MockOpentronsThermocyclerBackend(opentrons_id="test_id")
+    await stack.enter_async_context(self.thermocycler_backend)
 
   def test_opentrons_v1_serialization(self):
     """Test that the Opentrons-specific resource model serializes correctly."""
diff --git a/pylabrobot/thermocycling/opentrons_backend_usb.py b/pylabrobot/thermocycling/opentrons_backend_usb.py
index 41daf9a002c..11f96780a94 100644
--- a/pylabrobot/thermocycling/opentrons_backend_usb.py
+++ b/pylabrobot/thermocycling/opentrons_backend_usb.py
@@ -4,6 +4,9 @@
 import asyncio
 from typing import List, Optional
 
+import anyio
+
+from pylabrobot.concurrency import AsyncExitStackWithShielding
 from pylabrobot.thermocycling.backend import ThermocyclerBackend
 from pylabrobot.thermocycling.standard import (
   BlockStatus,
@@ -46,22 +49,20 @@ async def set_temperature_no_pause(
 
 async def wait_for_block_target(driver) -> None:
   """Wait for block temperature to reach target."""
-  max_attempts = 300  # 5 minutes max wait (300 * 1 second)
-  attempt = 0
-
-  while attempt < max_attempts:
-    try:
-      plate_temp = await driver.get_plate_temperature()
-      if plate_temp.target is not None and abs(plate_temp.current - plate_temp.target) < 1.0:
-        break
-    except Exception as e:
-      if "invalid thermistor" in str(e).lower() or "error" in str(e).lower():
-        raise RuntimeError(f"Thermocycler hardware error: {e}")
-      print(f"Temperature check failed (attempt {attempt + 1}), retrying: {e}")
-    attempt += 1
-    await asyncio.sleep(1.0)
-  else:
-    raise TimeoutError(f"Temperature did not reach target within {max_attempts} seconds")
+  try:
+    with anyio.fail_after(300):  # 5 minutes max wait (300 * 1 second)
+      while True:
+        try:
+          plate_temp = await driver.get_plate_temperature()
+          if plate_temp.target is not None and abs(plate_temp.current - plate_temp.target) < 1.0:
+            break
+        except Exception as e:
+          if "invalid thermistor" in str(e).lower() or "error" in str(e).lower():
+            raise RuntimeError(f"Thermocycler hardware error: {e}")
+          print(f"Temperature check failed, retrying: {e}")
+        await anyio.sleep(1.0)
+  except TimeoutError:
+    raise TimeoutError("Temperature did not reach target within 300 seconds")
 
 
 async def execute_cycle_step(
@@ -91,15 +92,15 @@ class OpentronsThermocyclerUSBBackend(ThermocyclerBackend):
     (0x0483, 0xED8D),  # STMicroelectronics bridge seen in newer units
   }
 
-  def __init__(self):
+  def __init__(self, port: Optional[str] = None):
     """Create a new USB backend."""
     super().__init__()
     if not USE_OPENTRONS_DRIVER:
       raise RuntimeError("Opentrons thermocycler driver not available") from _import_error
+    self.port = port
     self._driver: Optional[AbstractThermocyclerDriver] = None
     self._current_protocol: Optional[Protocol] = None
-    self._loop: Optional[asyncio.AbstractEventLoop] = None
     self._total_cycle_count: Optional[int] = None
     self._current_cycle_index: Optional[int] = None
 
@@ -188,11 +189,10 @@ async def run_protocol(self, protocol: Protocol, block_max_volume: float):
 
     self._current_protocol = protocol
 
-  async def setup(self, port: Optional[str] = None):
-    """Setup the USB connection to the thermocycler."""
-    if self._loop is None:
-      self._loop = asyncio.get_event_loop()
+  async def _enter_lifespan(self, stack: AsyncExitStackWithShielding):
+    await super()._enter_lifespan(stack)
 
+    port = self.port
     if port is None:
       ports = serial.tools.list_ports.comports()
       opentrons_ports = [p for p in ports if (p.vid, p.pid) in self.SUPPORTED_USB_IDS]
@@ -208,15 +208,19 @@ async def setup(self, port: Optional[str] = None):
       else:
         port = opentrons_ports[0].device
 
-    self._driver = await ThermocyclerDriverFactory.create(port, self._loop)
+    self._driver = await ThermocyclerDriverFactory.create(port, asyncio.get_running_loop())
 
-  async def stop(self):
-    if self._driver is not None:
-      await self.deactivate_block()
-      await self.deactivate_lid()
-      await self._driver.disconnect()
+    async def _cleanup():
+      try:
+        await self.deactivate_block()
+        await self.deactivate_lid()
+      finally:
+        assert self._driver is not None
+        await self._driver.disconnect()
         self._driver = None
 
+    stack.push_shielded_async_callback(_cleanup)
+
   async def open_lid(self):
     assert self._driver is not None
     await self._driver.open_lid()
diff --git a/pylabrobot/thermocycling/thermo_fisher/proflex_tests.py b/pylabrobot/thermocycling/thermo_fisher/proflex_tests.py
index 0f37d0ba393..746181c2dae 100644
--- a/pylabrobot/thermocycling/thermo_fisher/proflex_tests.py
+++ b/pylabrobot/thermocycling/thermo_fisher/proflex_tests.py
@@ -2,13 +2,14 @@
 import unittest
 import unittest.mock
 
+from pylabrobot.testing.concurrency import AnyioTestBase
 from pylabrobot.thermocycling.standard import Protocol, Stage, Step
 from pylabrobot.thermocycling.thermo_fisher.proflex import ProflexBackend
 
 
-class TestProflexBackend(unittest.IsolatedAsyncioTestCase):
-  async def asyncSetUp(self):
-    await super().asyncSetUp()
+class TestProflexBackend(AnyioTestBase):
+  async def _enter_lifespan(self, stack):
+    await super()._enter_lifespan(stack)
     self.proflex = ProflexBackend(ip="1.2.3.4")
     self.proflex.num_temp_zones = 2
     self.proflex.io.write = unittest.mock.AsyncMock()  # type: ignore
diff --git a/pylabrobot/thermocycling/thermo_fisher/thermo_fisher_thermocycler.py b/pylabrobot/thermocycling/thermo_fisher/thermo_fisher_thermocycler.py
index 86bab4c0e3d..257dfdc5624 100644
--- a/pylabrobot/thermocycling/thermo_fisher/thermo_fisher_thermocycler.py
+++ b/pylabrobot/thermocycling/thermo_fisher/thermo_fisher_thermocycler.py
@@ -1,4 +1,3 @@
-import asyncio
 import hashlib
 import hmac
 import logging
@@ -12,6 +11,9 @@
 from typing import Any, Dict, List, Optional, cast
 from xml.dom import minidom
 
+import anyio
+
+from pylabrobot.concurrency import AsyncExitStackWithShielding
 from pylabrobot.io import Socket
 from pylabrobot.thermocycling.backend import ThermocyclerBackend
 from pylabrobot.thermocycling.standard import LidStatus, Protocol, Stage, Step
@@ -423,7 +425,6 @@ async def send_command(self, data, response_timeout=1, read_once=True):
     return await self._read_response(timeout=response_timeout, read_once=read_once)
 
   async def _scpi_authenticate(self):
-    await self.io.setup()
     await self._read_response(timeout=5)
     challenge_res = await self.send_command({"cmd": "CHAL?"})
     challenge = self._parse_scpi_response(challenge_res)["args"][0]
@@ -614,19 +615,19 @@ async def send_morse_code(self, morse_code: str):
     for char in morse_code:
       if char == ".":
         await self.buzzer_on()
-        await asyncio.sleep(short_beep_duration)
+        await anyio.sleep(short_beep_duration)
         await self.buzzer_off()
       elif char == "-":
         await self.buzzer_on()
-        await asyncio.sleep(long_beep_duration)
+        await anyio.sleep(long_beep_duration)
         await self.buzzer_off()
       elif char == " ":
-        await asyncio.sleep(space_duration)
-      await asyncio.sleep(short_beep_duration)  # between letters is a short unit
+        await anyio.sleep(space_duration)
+      await anyio.sleep(short_beep_duration)  # between letters is a short unit
 
   async def continue_run(self, block_id: int):
     for _ in range(3):
-      await asyncio.sleep(1)
+      await anyio.sleep(1)
       res = await self.send_command({"cmd": f"TBC{block_id + 1}:CONTinue"})
       if self._parse_scpi_response(res)["status"] != "OK":
         raise ValueError("Failed to continue from indefinite hold")
@@ -816,7 +817,7 @@ async def abort_run(self, block_id: int):
       self.logger.error("Failed to abort protocol")
       raise ValueError("Failed to abort protocol")
     self.logger.info("Protocol aborted")
-    await asyncio.sleep(10)
+    await anyio.sleep(10)
 
   @dataclass
   class RunProgress:
@@ -870,7 +871,7 @@ async def get_run_info(self, protocol: Protocol, block_id: int) -> "RunProgress"
           abs(float(block_temps[i]) - target_temps[i]) < 0.5 for i in range(len(block_temps))
         ):
           break
-        await asyncio.sleep(5)
+        await anyio.sleep(5)
       self.logger.info("Infinite hold")
       return ThermoFisherThermocyclerBackend.RunProgress(
         running=False,
@@ -890,9 +891,34 @@ async def get_run_info(self, protocol: Protocol, block_id: int) -> "RunProgress"
 
   # *************Methods implementing ThermocyclerBackend***********************
 
-  async def setup(
-    self, block_idle_temp=25, cover_idle_temp=105, blocks_to_setup: Optional[List[int]] = None
+  async def _enter_lifespan(
+    self,
+    stack: AsyncExitStackWithShielding,
+    *,
+    block_idle_temp=25,
+    cover_idle_temp=105,
+    blocks_to_setup: Optional[List[int]] = None,
   ):
+    await super()._enter_lifespan(stack)
+    await stack.enter_async_context(self.io)
+
+    async def cleanup():
+      for block_id in list(self.current_runs.keys()):
+        try:
+          await self.abort_run(block_id=block_id)
+        except Exception as e:
+          self.logger.warning(f"Failed to abort run on block {block_id}: {e}")
+        try:
+          await self.deactivate_lid(block_id=block_id)
+        except Exception as e:
+          self.logger.warning(f"Failed to deactivate lid on block {block_id}: {e}")
+        try:
+          await self.deactivate_block(block_id=block_id)
+        except Exception as e:
+          self.logger.warning(f"Failed to deactivate block on block {block_id}: {e}")
+
+    stack.push_shielded_async_callback(cleanup)
+
     await self._scpi_authenticate()
     await self.power_on()
     await self._load_num_blocks_and_type()
@@ -979,14 +1005,7 @@ async def run_protocol(
       stage_name_prefixes=stage_name_prefixes,
     )
 
-  async def stop(self):
-    for block_id in self.current_runs.keys():
-      await self.abort_run(block_id=block_id)
-
-      await self.deactivate_lid(block_id=block_id)
-      await self.deactivate_block(block_id=block_id)
-
-    await self.io.stop()
+  # stop method removed, logic moved to cleanup callback in _enter_lifespan
 
   async def get_block_status(self, *args, **kwargs):
     raise NotImplementedError
diff --git a/pylabrobot/thermocycling/thermocycler.py b/pylabrobot/thermocycling/thermocycler.py
index 622599d47a2..319e81a367c 100644
--- a/pylabrobot/thermocycling/thermocycler.py
+++ b/pylabrobot/thermocycling/thermocycler.py
@@ -1,9 +1,9 @@
 """High-level Thermocycler resource wrapping a backend."""
 
-import asyncio
-import time
 from typing import List, Optional
 
+import anyio
+
 from pylabrobot.machines.machine import Machine
 from pylabrobot.resources import Coordinate, ResourceHolder
 from pylabrobot.thermocycling.backend import ThermocyclerBackend
@@ -225,13 +225,17 @@ async def get_total_step_count(self, **backend_kwargs) -> int:
   async def wait_for_block(self, timeout: float = 600, tolerance: float = 0.5, **backend_kwargs):
     """Wait until block temp reaches target ± tolerance for all zones."""
     targets = await self.get_block_target_temperature(**backend_kwargs)
-    start = time.time()
-    while time.time() - start < timeout:
-      currents = await self.get_block_current_temperature(**backend_kwargs)
-      if all(abs(current - target) < tolerance for current, target in zip(currents, targets)):
-        return
-      await asyncio.sleep(1)
-    raise TimeoutError("Block temperature timeout.")
+    try:
+      with anyio.fail_after(timeout):
+        while True:
+          currents = await self.get_block_current_temperature(**backend_kwargs)
+          if all(abs(current - target) < tolerance for current, target in zip(currents, targets)):
+            return
+          await anyio.sleep(1)
+    except TimeoutError:
+      raise TimeoutError(
+        f"Block temperature did not reach target within {timeout} seconds"
+      ) from None
 
   async def wait_for_lid(self, timeout: float = 1200, tolerance: float = 0.5, **backend_kwargs):
     """Wait until the lid temperature reaches target ± ``tolerance`` or the lid temperature status is idle/holding at target."""
@@ -239,19 +243,22 @@ async def wait_for_lid(self, timeout: float = 1200, tolerance: float = 0.5, **ba
       targets = await self.get_lid_target_temperature(**backend_kwargs)
     except RuntimeError:
       targets = None
-    start = time.time()
-    while time.time() - start < timeout:
-      if targets is not None:
-        currents = await self.get_lid_current_temperature(**backend_kwargs)
-        if all(abs(current - target) < tolerance for current, target in zip(currents, targets)):
-          return
-      else:
-        # If no target temperature, check status
-        status = await self.get_lid_status(**backend_kwargs)
-        if status in ["idle", "holding at target"]:
-          return
-      await asyncio.sleep(1)
-    raise TimeoutError("Lid temperature timeout.")
+
+    try:
+      with anyio.fail_after(timeout):
+        while True:
+          if targets is not None:
+            currents = await self.get_lid_current_temperature(**backend_kwargs)
+            if all(abs(current - target) < tolerance for current, target in zip(currents, targets)):
+              return
+          else:
+            # If no target temperature, check status
+            status = await self.get_lid_status(**backend_kwargs)
+            if status in ["idle", "holding at target"]:
+              return
+          await anyio.sleep(1)
+    except TimeoutError:
+      raise TimeoutError(f"Lid temperature did not reach target within {timeout} seconds") from None
 
   async def is_profile_running(self, **backend_kwargs) -> bool:
     """Return True if a profile is still in progress."""
@@ -275,7 +282,7 @@ async def is_profile_running(self, **backend_kwargs) -> bool:
   async def wait_for_profile_completion(self, poll_interval: float = 60.0, **backend_kwargs):
     """Block until the profile finishes, polling at `poll_interval` seconds."""
     while await self.is_profile_running(**backend_kwargs):
-      await asyncio.sleep(poll_interval)
+      await anyio.sleep(poll_interval)
 
   def serialize(self) -> dict:
     """JSON-serializable representation."""
diff --git a/pylabrobot/thermocycling/thermocycler_tests.py b/pylabrobot/thermocycling/thermocycler_tests.py
index db0d4821c6f..c2154ce43f4 100644
--- a/pylabrobot/thermocycling/thermocycler_tests.py
+++ b/pylabrobot/thermocycling/thermocycler_tests.py
@@ -1,8 +1,9 @@
-import asyncio
-import unittest
 from unittest.mock import AsyncMock, MagicMock
 
+import anyio
+
 from pylabrobot.resources import Coordinate
+from pylabrobot.testing.concurrency import AnyioTestBase
 from pylabrobot.thermocycling import (
   Thermocycler,
   ThermocyclerBackend,
@@ -38,9 +39,9 @@ def mock_backend() -> MagicMock:
   return mock
 
 
-class ThermocyclerTests(unittest.IsolatedAsyncioTestCase):
-  def __init__(self, *args, **kwargs):
-    super().__init__(*args, **kwargs)
+class TestThermocycler(AnyioTestBase):
+  async def _enter_lifespan(self, stack):
+    await super()._enter_lifespan(stack)
     self.tc = Thermocycler(
       name="test_tc",
       size_x=10,
@@ -49,6 +50,7 @@ def __init__(self, *args, **kwargs):
       backend=mock_backend(),
       child_location=Coordinate(0, 0, 0),
     )
+    await stack.enter_async_context(self.tc)
 
   def test_thermocycler_serialization(self):
     """Test that the high-level resource serializes and deserializes correctly."""
@@ -107,21 +109,27 @@ async def test_wait_for_profile_completion(self):
     """Test that wait_for_profile_completion correctly polls is_profile_running."""
     self.tc.backend.get_hold_time.side_effect = [10.0, 5.0, 0.0]  # type: ignore
 
-    # Patch asyncio.sleep to a no-op for the test.
-    original_sleep = asyncio.sleep
+    # Patch anyio.sleep to a no-op for the test.
+    original_sleep = anyio.sleep
 
     async def mock_sleep(*args, **kwargs):
       pass
 
-    asyncio.sleep = mock_sleep
+    anyio.sleep = mock_sleep
     try:
       await self.tc.wait_for_profile_completion(poll_interval=0.01)
       assert self.tc.backend.get_hold_time.call_count == 3  # type: ignore
     finally:
-      asyncio.sleep = original_sleep
+      anyio.sleep = original_sleep
 
   async def test_is_profile_running_logic(self):
     """Test that `is_profile_running` returns the correct boolean based on various profile states."""
+    # Justification for 0-based indexing test cases:
+    # The implementation of `is_profile_running()` relies on zero-based indexing for cycles and steps.
+    # The original test cases in `main` used 1-based values (e.g., testing step 3 out of 3), which were
+    # out-of-bounds. They passed by accident because out-of-bounds values failed the boundary checks.
+    # We corrected them to use the highest valid 0-based index (e.g., step 2 out of 3) to accurately
+    # test the boundary conditions.
     test_cases = [
       (10.0, 1, 10, 1, 3, True),
       (0.0, 5, 10, 1, 3, True),
diff --git a/pylabrobot/tilting/chatterbox.py b/pylabrobot/tilting/chatterbox.py
index 5ef3ea93be0..65809bc607e 100644
--- a/pylabrobot/tilting/chatterbox.py
+++ b/pylabrobot/tilting/chatterbox.py
@@ -1,12 +1,16 @@
+import contextlib
+
 from pylabrobot.tilting import TilterBackend
 
 
 class TilterChatterboxBackend(TilterBackend):
-  async def setup(self):
+  async def _enter_lifespan(self, stack: contextlib.AsyncExitStack):
     print("Setting up tilter.")
 
-  async def stop(self):
-    print("Stopping tilter.")
+    def _cleanup():
+      print("Stopping tilter.")
+
+    stack.callback(_cleanup)
 
   async def set_angle(self, angle: float):
     print(f"Setting the angle to {angle}.")
diff --git a/pylabrobot/tilting/hamilton_backend.py b/pylabrobot/tilting/hamilton_backend.py
index 34ea0337d8b..5effb345d52 100644
--- a/pylabrobot/tilting/hamilton_backend.py
+++ b/pylabrobot/tilting/hamilton_backend.py
@@ -1,6 +1,8 @@
 import re
 from typing import Optional
 
+from pylabrobot.concurrency import AsyncExitStackWithShielding
+
 try:
   import serial
 
@@ -42,14 +44,12 @@ def __init__(
       human_readable_device_name="Hamilton Tilt Module",
     )
 
-  async def setup(self, initial_offset: int = 0):
-    await self.io.setup()
+  async def _enter_lifespan(self, stack: AsyncExitStackWithShielding, *, initial_offset: int = 0):
+    await super()._enter_lifespan(stack)
+    await stack.enter_async_context(self.io)
     await self.tilt_initial_offset(initial_offset)
     await self.tilt_initialize()
 
-  async def stop(self):
-    await self.io.stop()
-
   async def send_command(self, command: str, parameter: Optional[str] = None) -> str:
     """Send a command to the tilt module."""
 
@@ -303,11 +303,10 @@ async def tilt_initial_offset(self, offset: int):
 
 
 class HamiltonTiltModuleChatterboxBackend(HamiltonTiltModuleBackend):
-  async def setup(self, initial_offset=0):
+  async def _enter_lifespan(self, stack: AsyncExitStackWithShielding, *, initial_offset=0):
+    await super()._enter_lifespan(stack, initial_offset=initial_offset)
     print(f"[tilter] setup initial offset {initial_offset}")
-
-  async def stop(self):
-    print("[tilter] stopping")
+    stack.callback(lambda: print("[tilter] stopping"))
 
   async def send_command(self, command, parameter=None):
     print(f"[tilter] Sending command: {command} with parameter: {parameter}")
diff --git a/pyproject.toml b/pyproject.toml
index a1225c3cd47..fa090008fe6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -9,7 +9,7 @@ readme = "README.md"
 requires-python = ">=3.9"
 license = {text = "MIT"}
 dynamic = ["version"]
-dependencies = ["typing_extensions", "websockets"]
+dependencies = ["typing_extensions", "websockets", "anyio", "sniffio"]
 
 [project.optional-dependencies]
 serial = ["pyserial"]
@@ -25,6 +25,7 @@ all = ["PyLabRobot[serial,usb,ftdi,hid,modbus,websockets,visualizer,opentrons,si
 test = [
   "pytest",
   "pytest-timeout",
+  "trio",
 ]
 dev = [
   "PyLabRobot[all,test]",
diff --git a/pytest.ini b/pytest.ini
index 2750f4aa22c..45902db833d 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -3,4 +3,4 @@
 python_files = *_tests.py
 markers =
   hardware: tests requiring connected devices
-addopts = -m "not hardware"
\ No newline at end of file
+addopts = -m "not hardware"
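The recurring change in this patch is replacing paired `setup`/`stop` methods with a single `_enter_lifespan(stack)` hook that registers its own cleanup. The core mechanics can be sketched with only the standard library; `AsyncResource` and the fake backend below are simplified stand-ins, not PyLabRobot's actual classes (the real `AsyncExitStackWithShielding` additionally shields cleanup callbacks from cancellation, which plain `contextlib.AsyncExitStack` does not):

```python
import asyncio
import contextlib


class AsyncResource:
  """Sketch: __aenter__/__aexit__ are written once; subclasses only
  override _enter_lifespan and register cleanup on the stack."""

  async def __aenter__(self):
    self._stack = contextlib.AsyncExitStack()
    await self._stack.__aenter__()
    try:
      await self._enter_lifespan(self._stack)
    except BaseException:
      # If startup fails partway, unwind whatever was already registered.
      await self._stack.aclose()
      raise
    return self

  async def __aexit__(self, *exc_info):
    # All registered cleanup runs here, in LIFO order.
    return await self._stack.__aexit__(*exc_info)

  async def _enter_lifespan(self, stack: contextlib.AsyncExitStack):
    pass  # subclasses override and call super()


class FakeThermocycler(AsyncResource):
  def __init__(self):
    self.log = []

  async def _enter_lifespan(self, stack):
    await super()._enter_lifespan(stack)
    self.log.append("connect")
    # Cleanup is registered at the point the matching resource is
    # acquired, so it runs even if a later startup step raises.
    stack.push_async_callback(self.deactivate)

  async def deactivate(self):
    self.log.append("deactivate")


async def main():
  tc = FakeThermocycler()
  async with tc:
    tc.log.append("run protocol")
  print(tc.log)  # ['connect', 'run protocol', 'deactivate']


asyncio.run(main())
```

This is why the patch can delete every `stop` method: exiting the `async with` block (or a startup failure) triggers the stacked callbacks automatically, instead of relying on callers to remember a separate teardown call.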