diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c3642cc..5d39b00 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -57,5 +57,7 @@ repos: hooks: - id: mypy exclude: cli.py - additional_dependencies: [ "pydantic>=2.0.0", "pytest>=8.0.0" ] + additional_dependencies: + - "pydantic>=2.0.0" + - "pytest>=8.0.0" args: [ "--config-file=./pyproject.toml"] diff --git a/pyomnilogic_local/api/protocol.py b/pyomnilogic_local/api/protocol.py index 3e383bc..d2a8dbd 100644 --- a/pyomnilogic_local/api/protocol.py +++ b/pyomnilogic_local/api/protocol.py @@ -229,7 +229,12 @@ async def _wait_for_ack(self, ack_id: int) -> None: # Wait for either a message or an error data_task = asyncio.create_task(self.data_queue.get()) error_task = asyncio.create_task(self.error_queue.get()) - done, _ = await asyncio.wait([data_task, error_task], return_when=asyncio.FIRST_COMPLETED) + done, pending = await asyncio.wait([data_task, error_task], return_when=asyncio.FIRST_COMPLETED) + + # Cancel any pending tasks to avoid "Task was destroyed but it is pending" warnings + for task in pending: + task.cancel() + if error_task in done: exc = error_task.result() if isinstance(exc, Exception): diff --git a/pyomnilogic_local/omnilogic.py b/pyomnilogic_local/omnilogic.py index f92687f..b12fc64 100644 --- a/pyomnilogic_local/omnilogic.py +++ b/pyomnilogic_local/omnilogic.py @@ -15,6 +15,7 @@ if TYPE_CHECKING: from pyomnilogic_local._base import OmniEquipment + from pyomnilogic_local.bow import Bow from pyomnilogic_local.chlorinator import Chlorinator from pyomnilogic_local.chlorinator_equip import ChlorinatorEquipment from pyomnilogic_local.colorlogiclight import ColorLogicLight @@ -260,6 +261,12 @@ def all_csads(self) -> EquipmentDict[CSAD]: csads.extend(bow.csads.values()) return EquipmentDict(csads) + @property + def all_bows(self) -> EquipmentDict[Bow]: + """Returns all Bow instances across all bows in the backyard.""" + # Bows are stored directly in backyard as EquipmentDict already + return self.backyard.bow + # Equipment search methods def get_equipment_by_name(self, name: str) -> OmniEquipment[Any, Any] | None: """Find equipment by name across all equipment types. @@ -272,6 +279,7 @@ def get_equipment_by_name(self, name: str) -> OmniEquipment[Any, Any] | None: """ # Search all equipment types all_equipment: list[OmniEquipment[Any, Any]] = [] + all_equipment.extend([self.backyard]) all_equipment.extend(self.all_lights.values()) all_equipment.extend(self.all_relays.values()) all_equipment.extend(self.all_pumps.values()) @@ -283,6 +291,7 @@ def get_equipment_by_name(self, name: str) -> OmniEquipment[Any, Any] | None: all_equipment.extend(self.all_chlorinator_equipment.values()) all_equipment.extend(self.all_csads.values()) all_equipment.extend(self.all_csad_equipment.values()) + all_equipment.extend(self.all_bows.values()) all_equipment.extend(self.groups.values()) for equipment in all_equipment: @@ -302,6 +311,7 @@ def get_equipment_by_id(self, system_id: int) -> OmniEquipment[Any, Any] | None: """ # Search all equipment types all_equipment: list[OmniEquipment[Any, Any]] = [] + all_equipment.extend([self.backyard]) all_equipment.extend(self.all_lights.values()) all_equipment.extend(self.all_relays.values()) all_equipment.extend(self.all_pumps.values()) @@ -313,6 +323,7 @@ def get_equipment_by_id(self, system_id: int) -> OmniEquipment[Any, Any] | None: all_equipment.extend(self.all_chlorinator_equipment.values()) all_equipment.extend(self.all_csads.values()) all_equipment.extend(self.all_csad_equipment.values()) + all_equipment.extend(self.all_bows.values()) all_equipment.extend(self.groups.values()) all_equipment.extend(self.schedules.values()) diff --git a/pyproject.toml b/pyproject.toml index 7c75430..e941f97 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,7 +47,7 @@ python_version = "3.13" plugins = [ "pydantic.mypy" ] -follow_imports = "silent" +# follow_imports = "silent" strict = true ignore_missing_imports = true disallow_subclassing_any = false diff --git a/tests/test_protocol.py b/tests/test_protocol.py index d4c2459..512b318 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -14,7 +14,7 @@ import struct import time import zlib -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from unittest.mock import AsyncMock, MagicMock, patch from xml.etree import ElementTree as ET @@ -787,3 +787,41 @@ async def test_receive_file_fragmented_ignores_non_block_messages(caplog: pytest assert any("other than a blockmessage" in r.message for r in caplog.records) assert result == "data" + + +@pytest.mark.asyncio +async def test_wait_for_ack_cancels_pending_tasks() -> None: + """Test that pending tasks are properly cancelled in _wait_for_ack to avoid warnings.""" + protocol = OmniLogicProtocol() + protocol.transport = MagicMock() + + # Track tasks created during _wait_for_ack + created_tasks: list[asyncio.Task[Any]] = [] + original_create_task = asyncio.create_task + + def track_create_task(coro: Any) -> asyncio.Task[Any]: + task: asyncio.Task[Any] = original_create_task(coro) + created_tasks.append(task) + return task + + # Queue up an ACK message + ack_msg = OmniLogicMessage(42, MessageType.ACK) + await protocol.data_queue.put(ack_msg) + + # Patch create_task to track tasks + with patch("asyncio.create_task", side_effect=track_create_task): + await protocol._wait_for_ack(42) + + # Give the event loop a chance to process cancellation + await asyncio.sleep(0) + + # Should have created 2 tasks (data_task and error_task) + assert len(created_tasks) == 2 + + # One should be done (the data_task that got the ACK) + # One should be cancelled (the error_task that was waiting) + done_tasks = [t for t in created_tasks if t.done() and not t.cancelled()] + cancelled_tasks = [t for t in created_tasks if t.cancelled()] + + assert len(done_tasks) == 1, "Expected exactly one task to complete normally" + assert len(cancelled_tasks) == 1, "Expected exactly one task to be cancelled"