From 321aa5a10c06a665552c8c05fff259900afaf209 Mon Sep 17 00:00:00 2001 From: Jay Newstrom Date: Sat, 9 May 2026 19:26:12 -0600 Subject: [PATCH 1/2] allow @webhook_trigger functions to control the http response Webhook handlers now wait for the decorated function and use its return value to build the response: int -> Response(status=...), aiohttp.web.Response -> as-is, None -> default 200. When multiple triggers share a webhook_id, the first non-None return wins. Co-Authored-By: Claude Opus 4.7 --- custom_components/pyscript/decorator.py | 4 + custom_components/pyscript/decorator_abc.py | 11 ++ .../pyscript/decorators/webhook.py | 55 +++++- .../pyscript/stubs/pyscript_builtins.py | 2 + docs/reference.rst | 12 ++ tests/test_decorator_manager.py | 80 +++++++- tests/test_decorators.py | 171 +++++++++++++++++- 7 files changed, 328 insertions(+), 7 deletions(-) diff --git a/custom_components/pyscript/decorator.py b/custom_components/pyscript/decorator.py index 33af0d3..8eb7898 100644 --- a/custom_components/pyscript/decorator.py +++ b/custom_components/pyscript/decorator.py @@ -264,6 +264,7 @@ async def _call(self, data: DispatchData) -> None: # notify handlers with "None" for result_handler_dec in result_handlers: await result_handler_dec.handle_call_result(data, None) + data.set_result(None) return # Fire an event indicating that pyscript is running # Note: the event must have an entity_id for logbook to work correctly. @@ -279,7 +280,9 @@ async def _call(self, data: DispatchData) -> None: result = await data.call_ast_ctx.call_func(self.eval_func, None, **data.func_args) for result_handler_dec in result_handlers: await result_handler_dec.handle_call_result(data, result) + data.set_result(result) except Exception as e: + data.set_result(None) await self.handle_exception(e) async def dispatch(self, data: DispatchData) -> None: @@ -290,6 +293,7 @@ async def dispatch(self, data: DispatchData) -> None: for dec in decorators: if await dec.handle_dispatch(data) is False: self.logger.debug("Trigger not active due to %s", dec) + data.set_result(None) return action_ast_ctx = AstEval( diff --git a/custom_components/pyscript/decorator_abc.py b/custom_components/pyscript/decorator_abc.py index 4775317..d24005b 100644 --- a/custom_components/pyscript/decorator_abc.py +++ b/custom_components/pyscript/decorator_abc.py @@ -3,6 +3,7 @@ from __future__ import annotations from abc import ABC, abstractmethod +import asyncio from dataclasses import dataclass, field from enum import StrEnum import logging @@ -46,6 +47,16 @@ class DispatchData: call_ast_ctx: AstEval | None = field(default=None, kw_only=True) hass_context: Context | None = field(default=None, kw_only=True) + # When set, the dispatch pipeline resolves this future with the + # decorated function's return value. Resolved with None if the + # function is skipped (guard rejection) or raises. + result_future: asyncio.Future[Any] | None = field(default=None, kw_only=True) + + def set_result(self, value: Any) -> None: + """Resolve result_future with value if it is still pending.""" + if self.result_future is not None and not self.result_future.done(): + self.result_future.set_result(value) + class Decorator(ABC): """Generic decorator abstraction.""" diff --git a/custom_components/pyscript/decorators/webhook.py b/custom_components/pyscript/decorators/webhook.py index 3db0a09..ea72632 100644 --- a/custom_components/pyscript/decorators/webhook.py +++ b/custom_components/pyscript/decorators/webhook.py @@ -2,10 +2,11 @@ from __future__ import annotations +import asyncio import logging -from typing import ClassVar +from typing import Any, ClassVar -from aiohttp import hdrs +from aiohttp import hdrs, web import voluptuous as vol from homeassistant.components import webhook @@ -32,12 +33,14 @@ class WebhookTriggerDecorator(TriggerDecorator, ExpressionDecorator, AutoKwargsD { vol.Optional("local_only", default=True): cv.boolean, vol.Optional("methods"): vol.All(list[str], [vol.In(SUPPORTED_METHODS)]), + vol.Optional("sets_http_response_code", default=False): cv.boolean, } ) webhook_id: str local_only: bool methods: set[str] + sets_http_response_code: bool webhook_id2triggers: ClassVar[dict[str, set[WebhookTriggerDecorator]]] = {} @@ -50,7 +53,7 @@ async def validate(self): self.create_expression(self.args[1]) @staticmethod - async def _handler(_hass, webhook_id, request): + async def _handler(hass, webhook_id, request): func_args = { "trigger_type": "webhook", "webhook_id": webhook_id, @@ -64,17 +67,59 @@ async def _handler(_hass, webhook_id, request): payload_multidict = await request.post() func_args["payload"] = {k: payload_multidict.getone(k) for k in payload_multidict.keys()} + response_future: asyncio.Future[Any] | None = None + futures: list[asyncio.Future[Any]] = [] for trigger in WebhookTriggerDecorator.webhook_id2triggers.get(webhook_id, set()).copy(): trigger_args = func_args.copy() if trigger.has_expression(): if not await trigger.check_expression_vars(trigger_args): continue - await trigger.dispatch(DispatchData(trigger_args)) + future: asyncio.Future[Any] = hass.loop.create_future() + if trigger.sets_http_response_code: + response_future = future + futures.append(future) + await trigger.dispatch(DispatchData(trigger_args, result_future=future)) + + if not futures: + return None + + await asyncio.gather(*futures, return_exceptions=True) + + if response_future is None: + return None + return WebhookTriggerDecorator.coerce_response(response_future.result()) + + @staticmethod + def coerce_response(value: Any) -> web.Response | None: + """Convert a webhook function return value to an aiohttp Response.""" + if value is None: + return None + if isinstance(value, web.Response): + return value + # bool is a subclass of int; reject it so True/False don't become 1/0 status codes. + if isinstance(value, int) and not isinstance(value, bool): + return web.Response(status=value) + _LOGGER.warning( + "webhook function returned unsupported type %s; expected int status code or aiohttp.web.Response", + type(value).__name__, + ) + return None @staticmethod def _add_trigger(trigger: WebhookTriggerDecorator) -> None: webhook_id = trigger.webhook_id - if webhook_id not in WebhookTriggerDecorator.webhook_id2triggers: + existing = WebhookTriggerDecorator.webhook_id2triggers.get(webhook_id) + if ( + trigger.sets_http_response_code + and existing is not None + and any(t.sets_http_response_code for t in existing) + ): + raise ValueError( + f"webhook_id '{webhook_id}' already has a @webhook_trigger with " + f"sets_http_response_code=True; only one is allowed" + ) + + if existing is None: webhook.async_register( trigger.dm.hass, "pyscript", # DOMAIN diff --git a/custom_components/pyscript/stubs/pyscript_builtins.py b/custom_components/pyscript/stubs/pyscript_builtins.py index ea75580..61d3517 100644 --- a/custom_components/pyscript/stubs/pyscript_builtins.py +++ b/custom_components/pyscript/stubs/pyscript_builtins.py @@ -127,6 +127,7 @@ def webhook_trigger( str_expr: str | None = None, local_only: bool = True, methods: set[SUPPORTED_METHODS] | list[SUPPORTED_METHODS] = {"POST", "PUT"}, + sets_http_response_code: bool = False, kwargs: dict | None = None, ) -> Callable[..., Any]: """Trigger when a request is made to a webhook endpoint. @@ -136,6 +137,7 @@ def webhook_trigger( str_expr: Optional expression evaluated against ``trigger_type``, ``webhook_id``, ``request``, and ``payload``. local_only: If False, allow requests from anywhere on the internet. methods: HTTP methods to allow. + sets_http_response_code: If True, the function's return value drives the HTTP response (``int`` status code or ``aiohttp.web.Response``); at most one trigger per ``webhook_id`` may set this. kwargs: Extra keyword arguments merged into each invocation. Trigger kwargs include ``trigger_type="webhook"``, ``webhook_id``, the parsed payload fields, and ``request`` (the underlying ``aiohttp.web.Request``). diff --git a/docs/reference.rst b/docs/reference.rst index 3b7c587..2781215 100644 --- a/docs/reference.rst +++ b/docs/reference.rst @@ -915,6 +915,18 @@ To validate an HMAC signature on incoming requests, declare ``request`` in the f return log.info(f"verified webhook: {payload}") +To control the HTTP response sent back to the webhook caller, opt in by passing ``sets_http_response_code=True``. The flagged function's return value then drives the response: ``None`` produces a ``200 OK``, an ``int`` sends back a response with that status code, and an ``aiohttp.web.Response`` allows full control over the body and headers. Return values from triggers without the flag are ignored. For example: + +.. code:: python + + @webhook_trigger("myid", sets_http_response_code=True) + def webhook_check(payload): + if "token" not in payload: + return 401 + return 204 + +At most one ``@webhook_trigger`` per ``webhook_id`` may set ``sets_http_response_code=True``; declaring more than one is an error at setup time. The webhook handler waits for all decorated function(s) for the ``webhook_id`` to finish before responding, so use ``task.create()`` to fire-and-forget any long-running work. + NOTE: A webhook_id can only be used by either a built-in Home Assistant automation or pyscript, but not both. Trying to use the same webhook_id in both will result in an error. @state_active diff --git a/tests/test_decorator_manager.py b/tests/test_decorator_manager.py index 45c6f89..f904871 100644 --- a/tests/test_decorator_manager.py +++ b/tests/test_decorator_manager.py @@ -2,6 +2,7 @@ from __future__ import annotations +import asyncio import logging from typing import ClassVar from unittest.mock import patch @@ -353,9 +354,15 @@ def make_dispatch_data( *, call_ast_ctx: DummyCallAstCtx | None = None, hass_context: Context | None = None, + result_future: asyncio.Future | None = None, ) -> DispatchData: """Build DispatchData from test doubles.""" - return DispatchData(func_args, call_ast_ctx=call_ast_ctx, hass_context=hass_context) + return DispatchData( + func_args, + call_ast_ctx=call_ast_ctx, + hass_context=hass_context, + result_future=result_future, + ) def setup_global_context_function_hass(hass: HomeAssistant, config_data: dict | None = None) -> None: @@ -599,6 +606,77 @@ async def test_function_decorator_manager_logs_call_exception(hass): assert str(ast_ctx.logged_exceptions[0]) == "decorated call failed" +@pytest.mark.asyncio +async def test_function_decorator_manager_result_future_success(hass): + """Successful calls should resolve result_future with the function's return value.""" + DecoratorManager.hass = hass + manager = FunctionDecoratorManager(DummyAstCtx(), DummyEvalFuncVar()) + call_ast_ctx = DummyCallAstCtx(result=201) + future: asyncio.Future = hass.loop.create_future() + + with patch.object(Function, "store_hass_context"): + await call_function_manager( + manager, + make_dispatch_data( + {"arg1": 1}, + call_ast_ctx=call_ast_ctx, + hass_context=Context(id="call-parent"), + result_future=future, + ), + ) + await hass.async_block_till_done() + + assert future.done() + assert future.result() == 201 + + +@pytest.mark.asyncio +async def test_function_decorator_manager_result_future_cancel(hass): + """When a call handler cancels, result_future should resolve to None.""" + DecoratorManager.hass = hass + manager = FunctionDecoratorManager(DummyAstCtx(), DummyEvalFuncVar()) + manager.add(make_cancel_call_handler()) + future: asyncio.Future = hass.loop.create_future() + + await call_function_manager( + manager, + make_dispatch_data( + {"arg1": 1}, + call_ast_ctx=DummyCallAstCtx(result="unused"), + hass_context=Context(id="call-parent"), + result_future=future, + ), + ) + + assert future.done() + assert future.result() is None + + +@pytest.mark.asyncio +async def test_function_decorator_manager_result_future_exception(hass): + """When the decorated function raises, result_future should resolve to None.""" + DecoratorManager.hass = hass + ast_ctx = DummyAstCtx() + manager = FunctionDecoratorManager(ast_ctx, DummyEvalFuncVar()) + call_ast_ctx = DummyCallAstCtx(exc=RuntimeError("boom")) + future: asyncio.Future = hass.loop.create_future() + + with patch.object(Function, "store_hass_context"): + await call_function_manager( + manager, + make_dispatch_data( + {"arg1": 1}, + call_ast_ctx=call_ast_ctx, + hass_context=Context(id="call-parent"), + result_future=future, + ), + ) + + assert future.done() + assert future.result() is None + assert len(ast_ctx.logged_exceptions) == 1 + + def test_decorator_registry_register_requires_name(): """Registry should reject decorators without a declared name.""" diff --git a/tests/test_decorators.py b/tests/test_decorators.py index 12224d4..a214787 100644 --- a/tests/test_decorators.py +++ b/tests/test_decorators.py @@ -3,12 +3,15 @@ from ast import literal_eval import asyncio from datetime import datetime as dt -from unittest.mock import mock_open, patch +from http import HTTPStatus +from unittest.mock import Mock, mock_open, patch +from aiohttp import web import pytest from custom_components.pyscript import trigger from custom_components.pyscript.const import DOMAIN +from custom_components.pyscript.decorators.webhook import WebhookTriggerDecorator from custom_components.pyscript.function import Function from homeassistant.components import webhook from homeassistant.const import EVENT_HOMEASSISTANT_STARTED, EVENT_STATE_CHANGED @@ -256,3 +259,169 @@ def webhook_test(payload, request): await webhook.async_handle_webhook(hass, "test_req_hook", request) assert literal_eval(await wait_until_done(notify_q)) == ["abc123", "POST", {"hello": "world"}] + + +def _post_webhook_request() -> MockRequest: + """Build a MockRequest representing a webhook POST with form data.""" + return MockRequest( + content=b"", + headers={}, + method="POST", + query_string="", + mock_source="test", + remote="127.0.0.1", + ) + + +def test_webhook_coerce_response_none(): + """A None return should fall through to the HA default response.""" + assert WebhookTriggerDecorator.coerce_response(None) is None + + +def test_webhook_coerce_response_int(): + """Int returns should produce an aiohttp Response with that status.""" + response = WebhookTriggerDecorator.coerce_response(HTTPStatus.CREATED.value) + assert isinstance(response, web.Response) + assert response.status == HTTPStatus.CREATED + + +def test_webhook_coerce_response_passthrough(): + """An aiohttp Response should be returned unchanged.""" + custom = web.Response(status=HTTPStatus.ACCEPTED, body=b"queued") + assert WebhookTriggerDecorator.coerce_response(custom) is custom + + +def test_webhook_coerce_response_bool_warns(caplog): + """Bool returns should be rejected so True/False don't masquerade as 1/0.""" + assert WebhookTriggerDecorator.coerce_response(True) is None + assert "unsupported type bool" in caplog.text + + +def test_webhook_coerce_response_unsupported_warns(caplog): + """Other return types should warn and fall through.""" + assert WebhookTriggerDecorator.coerce_response("ok") is None + assert "unsupported type str" in caplog.text + + +@pytest.mark.asyncio +async def test_webhook_function_returns_status_code(hass): + """A flagged webhook function returning an int should set the HTTP status.""" + await setup_script( + hass, + None, + dt(2020, 7, 1, 11, 59, 59, 999999), + """ +@webhook_trigger("status_hook", sets_http_response_code=True) +def func_status(payload): + return 201 +""", + ) + hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED) + await hass.async_block_till_done() + response = await webhook.async_handle_webhook(hass, "status_hook", _post_webhook_request()) + await hass.async_block_till_done() + assert response.status == HTTPStatus.CREATED + + +@pytest.mark.asyncio +async def test_webhook_function_default_response(hass): + """A pyscript webhook function returning None should produce a 200 OK.""" + await setup_script( + hass, + None, + dt(2020, 7, 1, 11, 59, 59, 999999), + """ +@webhook_trigger("default_hook") +def func_default(payload): + pass +""", + ) + hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED) + await hass.async_block_till_done() + response = await webhook.async_handle_webhook(hass, "default_hook", _post_webhook_request()) + await hass.async_block_till_done() + assert response.status == HTTPStatus.OK + + +@pytest.mark.asyncio +async def test_webhook_unflagged_return_ignored(hass): + """A return value from a trigger without sets_http_response_code is ignored.""" + await setup_script( + hass, + None, + dt(2020, 7, 1, 11, 59, 59, 999999), + """ +@webhook_trigger("unflagged_hook") +def func_unflagged(payload): + return 418 +""", + ) + hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED) + await hass.async_block_till_done() + response = await webhook.async_handle_webhook(hass, "unflagged_hook", _post_webhook_request()) + await hass.async_block_till_done() + assert response.status == HTTPStatus.OK + + +@pytest.mark.asyncio +async def test_webhook_only_flagged_trigger_controls_response(hass): + """When multiple triggers share an id, only the flagged one drives the response.""" + await setup_script( + hass, + None, + dt(2020, 7, 1, 11, 59, 59, 999999), + """ +@webhook_trigger("multi_hook") +def func_silent(payload): + return 500 + +@webhook_trigger("multi_hook", sets_http_response_code=True) +def func_loud(payload): + return 418 +""", + ) + hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED) + await hass.async_block_till_done() + response = await webhook.async_handle_webhook(hass, "multi_hook", _post_webhook_request()) + await hass.async_block_till_done() + assert response.status == 418 + + +@pytest.mark.asyncio +async def test_webhook_flag_is_per_webhook_id(hass): + """Different webhook_ids may each have their own sets_http_response_code=True trigger.""" + await setup_script( + hass, + None, + dt(2020, 7, 1, 11, 59, 59, 999999), + """ +@webhook_trigger("hook_a", sets_http_response_code=True) +def func_a(payload): + return 201 + +@webhook_trigger("hook_b", sets_http_response_code=True) +def func_b(payload): + return 202 +""", + ) + hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED) + await hass.async_block_till_done() + response_a = await webhook.async_handle_webhook(hass, "hook_a", _post_webhook_request()) + response_b = await webhook.async_handle_webhook(hass, "hook_b", _post_webhook_request()) + await hass.async_block_till_done() + assert response_a.status == HTTPStatus.CREATED + assert response_b.status == HTTPStatus.ACCEPTED + + +def test_webhook_multiple_flagged_triggers_fails_at_setup(): + """A second flagged trigger for the same webhook_id should be rejected at start.""" + first = Mock(webhook_id="dup_hook", sets_http_response_code=True) + second = Mock(webhook_id="dup_hook", sets_http_response_code=True) + + WebhookTriggerDecorator.webhook_id2triggers.pop("dup_hook", None) + WebhookTriggerDecorator.webhook_id2triggers["dup_hook"] = {first} + try: + with pytest.raises(ValueError, match="sets_http_response_code=True"): + WebhookTriggerDecorator._add_trigger(second) # pylint: disable=protected-access + finally: + WebhookTriggerDecorator.webhook_id2triggers.pop("dup_hook", None) From 7949e9f9b2fe179895ca13d5bd4094188d84d0a2 Mon Sep 17 00:00:00 2001 From: Jay Newstrom Date: Tue, 12 May 2026 14:26:33 -0600 Subject: [PATCH 2/2] use CallResultHandlerDecorator for webhook response future Address PR review feedback: instead of adding a new result_future field to DispatchData, route the response future through the existing CallResultHandlerDecorator pipeline. --- custom_components/pyscript/decorator.py | 8 +-- custom_components/pyscript/decorator_abc.py | 11 ---- .../pyscript/decorators/webhook.py | 18 +++++- tests/test_decorator_manager.py | 60 ++----------------- 4 files changed, 24 insertions(+), 73 deletions(-) diff --git a/custom_components/pyscript/decorator.py b/custom_components/pyscript/decorator.py index 8eb7898..98b792c 100644 --- a/custom_components/pyscript/decorator.py +++ b/custom_components/pyscript/decorator.py @@ -264,7 +264,6 @@ async def _call(self, data: DispatchData) -> None: # notify handlers with "None" for result_handler_dec in result_handlers: await result_handler_dec.handle_call_result(data, None) - data.set_result(None) return # Fire an event indicating that pyscript is running # Note: the event must have an entity_id for logbook to work correctly. @@ -280,9 +279,9 @@ async def _call(self, data: DispatchData) -> None: result = await data.call_ast_ctx.call_func(self.eval_func, None, **data.func_args) for result_handler_dec in result_handlers: await result_handler_dec.handle_call_result(data, result) - data.set_result(result) except Exception as e: - data.set_result(None) + for result_handler_dec in result_handlers: + await result_handler_dec.handle_call_result(data, None) await self.handle_exception(e) async def dispatch(self, data: DispatchData) -> None: @@ -293,7 +292,8 @@ async def dispatch(self, data: DispatchData) -> None: for dec in decorators: if await dec.handle_dispatch(data) is False: self.logger.debug("Trigger not active due to %s", dec) - data.set_result(None) + for result_handler_dec in self.get_decorators(CallResultHandlerDecorator): + await result_handler_dec.handle_call_result(data, None) return action_ast_ctx = AstEval( diff --git a/custom_components/pyscript/decorator_abc.py b/custom_components/pyscript/decorator_abc.py index d24005b..4775317 100644 --- a/custom_components/pyscript/decorator_abc.py +++ b/custom_components/pyscript/decorator_abc.py @@ -3,7 +3,6 @@ from __future__ import annotations from abc import ABC, abstractmethod -import asyncio from dataclasses import dataclass, field from enum import StrEnum import logging @@ -47,16 +46,6 @@ class DispatchData: call_ast_ctx: AstEval | None = field(default=None, kw_only=True) hass_context: Context | None = field(default=None, kw_only=True) - # When set, the dispatch pipeline resolves this future with the - # decorated function's return value. Resolved with None if the - # function is skipped (guard rejection) or raises. - result_future: asyncio.Future[Any] | None = field(default=None, kw_only=True) - - def set_result(self, value: Any) -> None: - """Resolve result_future with value if it is still pending.""" - if self.result_future is not None and not self.result_future.done(): - self.result_future.set_result(value) - class Decorator(ABC): """Generic decorator abstraction.""" diff --git a/custom_components/pyscript/decorators/webhook.py b/custom_components/pyscript/decorators/webhook.py index ea72632..ec7e902 100644 --- a/custom_components/pyscript/decorators/webhook.py +++ b/custom_components/pyscript/decorators/webhook.py @@ -13,13 +13,15 @@ from homeassistant.components.webhook import SUPPORTED_METHODS from homeassistant.helpers import config_validation as cv -from ..decorator_abc import DispatchData, TriggerDecorator +from ..decorator_abc import CallResultHandlerDecorator, DispatchData, TriggerDecorator from .base import AutoKwargsDecorator, ExpressionDecorator _LOGGER = logging.getLogger(__name__) -class WebhookTriggerDecorator(TriggerDecorator, ExpressionDecorator, AutoKwargsDecorator): +class WebhookTriggerDecorator( + TriggerDecorator, ExpressionDecorator, AutoKwargsDecorator, CallResultHandlerDecorator +): """Implementation for @webhook_trigger.""" name = "webhook_trigger" @@ -41,6 +43,7 @@ class WebhookTriggerDecorator(TriggerDecorator, ExpressionDecorator, AutoKwargsD local_only: bool methods: set[str] sets_http_response_code: bool + future: asyncio.Future[Any] | None = None webhook_id2triggers: ClassVar[dict[str, set[WebhookTriggerDecorator]]] = {} @@ -75,10 +78,11 @@ async def _handler(hass, webhook_id, request): if not await trigger.check_expression_vars(trigger_args): continue future: asyncio.Future[Any] = hass.loop.create_future() + trigger.future = future if trigger.sets_http_response_code: response_future = future futures.append(future) - await trigger.dispatch(DispatchData(trigger_args, result_future=future)) + await trigger.dispatch(DispatchData(trigger_args)) if not futures: return None @@ -89,6 +93,14 @@ async def _handler(hass, webhook_id, request): return None return WebhookTriggerDecorator.coerce_response(response_future.result()) + async def handle_call_result(self, data: DispatchData, result: Any) -> None: + """Resolve the response future with the decorated function's return value.""" + if data.trigger is not self: + return + response_future = self.future + if response_future is not None and not response_future.done(): + response_future.set_result(result) + @staticmethod def coerce_response(value: Any) -> web.Response | None: """Convert a webhook function return value to an aiohttp Response.""" diff --git a/tests/test_decorator_manager.py b/tests/test_decorator_manager.py index f904871..575b96d 100644 --- a/tests/test_decorator_manager.py +++ b/tests/test_decorator_manager.py @@ -2,7 +2,6 @@ from __future__ import annotations -import asyncio import logging from typing import ClassVar from unittest.mock import patch @@ -354,14 +353,12 @@ def make_dispatch_data( *, call_ast_ctx: DummyCallAstCtx | None = None, hass_context: Context | None = None, - result_future: asyncio.Future | None = None, ) -> DispatchData: """Build DispatchData from test doubles.""" return DispatchData( func_args, call_ast_ctx=call_ast_ctx, hass_context=hass_context, - result_future=result_future, ) @@ -607,59 +604,14 @@ async def test_function_decorator_manager_logs_call_exception(hass): @pytest.mark.asyncio -async def test_function_decorator_manager_result_future_success(hass): - """Successful calls should resolve result_future with the function's return value.""" - DecoratorManager.hass = hass - manager = FunctionDecoratorManager(DummyAstCtx(), DummyEvalFuncVar()) - call_ast_ctx = DummyCallAstCtx(result=201) - future: asyncio.Future = hass.loop.create_future() - - with patch.object(Function, "store_hass_context"): - await call_function_manager( - manager, - make_dispatch_data( - {"arg1": 1}, - call_ast_ctx=call_ast_ctx, - hass_context=Context(id="call-parent"), - result_future=future, - ), - ) - await hass.async_block_till_done() - - assert future.done() - assert future.result() == 201 - - -@pytest.mark.asyncio -async def test_function_decorator_manager_result_future_cancel(hass): - """When a call handler cancels, result_future should resolve to None.""" - DecoratorManager.hass = hass - manager = FunctionDecoratorManager(DummyAstCtx(), DummyEvalFuncVar()) - manager.add(make_cancel_call_handler()) - future: asyncio.Future = hass.loop.create_future() - - await call_function_manager( - manager, - make_dispatch_data( - {"arg1": 1}, - call_ast_ctx=DummyCallAstCtx(result="unused"), - hass_context=Context(id="call-parent"), - result_future=future, - ), - ) - - assert future.done() - assert future.result() is None - - -@pytest.mark.asyncio -async def test_function_decorator_manager_result_future_exception(hass): - """When the decorated function raises, result_future should resolve to None.""" +async def test_function_decorator_manager_exception_calls_result_handlers(hass): + """When the decorated function raises, result handlers should be notified with None.""" DecoratorManager.hass = hass ast_ctx = DummyAstCtx() manager = FunctionDecoratorManager(ast_ctx, DummyEvalFuncVar()) + result_handler = make_recording_result_handler() + manager.add(result_handler) call_ast_ctx = DummyCallAstCtx(exc=RuntimeError("boom")) - future: asyncio.Future = hass.loop.create_future() with patch.object(Function, "store_hass_context"): await call_function_manager( @@ -668,12 +620,10 @@ async def test_function_decorator_manager_result_future_exception(hass): {"arg1": 1}, call_ast_ctx=call_ast_ctx, hass_context=Context(id="call-parent"), - result_future=future, ), ) - assert future.done() - assert future.result() is None + assert result_handler.results == [None] assert len(ast_ctx.logged_exceptions) == 1