diff --git a/custom_components/pyscript/decorator.py b/custom_components/pyscript/decorator.py index 33af0d3..98b792c 100644 --- a/custom_components/pyscript/decorator.py +++ b/custom_components/pyscript/decorator.py @@ -280,6 +280,8 @@ async def _call(self, data: DispatchData) -> None: for result_handler_dec in result_handlers: await result_handler_dec.handle_call_result(data, result) except Exception as e: + for result_handler_dec in result_handlers: + await result_handler_dec.handle_call_result(data, None) await self.handle_exception(e) async def dispatch(self, data: DispatchData) -> None: @@ -290,6 +292,8 @@ async def dispatch(self, data: DispatchData) -> None: for dec in decorators: if await dec.handle_dispatch(data) is False: self.logger.debug("Trigger not active due to %s", dec) + for result_handler_dec in self.get_decorators(CallResultHandlerDecorator): + await result_handler_dec.handle_call_result(data, None) return action_ast_ctx = AstEval( diff --git a/custom_components/pyscript/decorators/webhook.py b/custom_components/pyscript/decorators/webhook.py index 3db0a09..ec7e902 100644 --- a/custom_components/pyscript/decorators/webhook.py +++ b/custom_components/pyscript/decorators/webhook.py @@ -2,23 +2,26 @@ from __future__ import annotations +import asyncio import logging -from typing import ClassVar +from typing import Any, ClassVar -from aiohttp import hdrs +from aiohttp import hdrs, web import voluptuous as vol from homeassistant.components import webhook from homeassistant.components.webhook import SUPPORTED_METHODS from homeassistant.helpers import config_validation as cv -from ..decorator_abc import DispatchData, TriggerDecorator +from ..decorator_abc import CallResultHandlerDecorator, DispatchData, TriggerDecorator from .base import AutoKwargsDecorator, ExpressionDecorator _LOGGER = logging.getLogger(__name__) -class WebhookTriggerDecorator(TriggerDecorator, ExpressionDecorator, AutoKwargsDecorator): +class WebhookTriggerDecorator( + TriggerDecorator, ExpressionDecorator, AutoKwargsDecorator, CallResultHandlerDecorator +): """Implementation for @webhook_trigger.""" name = "webhook_trigger" @@ -32,12 +35,15 @@ class WebhookTriggerDecorator(TriggerDecorator, ExpressionDecorator, AutoKwargsD { vol.Optional("local_only", default=True): cv.boolean, vol.Optional("methods"): vol.All(list[str], [vol.In(SUPPORTED_METHODS)]), + vol.Optional("sets_http_response_code", default=False): cv.boolean, } ) webhook_id: str local_only: bool methods: set[str] + sets_http_response_code: bool + future: asyncio.Future[Any] | None = None webhook_id2triggers: ClassVar[dict[str, set[WebhookTriggerDecorator]]] = {} @@ -50,7 +56,7 @@ async def validate(self): self.create_expression(self.args[1]) @staticmethod - async def _handler(_hass, webhook_id, request): + async def _handler(hass, webhook_id, request): func_args = { "trigger_type": "webhook", "webhook_id": webhook_id, @@ -64,17 +70,68 @@ async def _handler(_hass, webhook_id, request): payload_multidict = await request.post() func_args["payload"] = {k: payload_multidict.getone(k) for k in payload_multidict.keys()} + response_future: asyncio.Future[Any] | None = None + futures: list[asyncio.Future[Any]] = [] for trigger in WebhookTriggerDecorator.webhook_id2triggers.get(webhook_id, set()).copy(): trigger_args = func_args.copy() if trigger.has_expression(): if not await trigger.check_expression_vars(trigger_args): continue + future: asyncio.Future[Any] = hass.loop.create_future() + trigger.future = future + if trigger.sets_http_response_code: + response_future = future + futures.append(future) await trigger.dispatch(DispatchData(trigger_args)) + if not futures: + return None + + await asyncio.gather(*futures, return_exceptions=True) + + if response_future is None: + return None + return WebhookTriggerDecorator.coerce_response(response_future.result()) + + async def handle_call_result(self, data: DispatchData, result: Any) -> None: + """Resolve the response future with the decorated function's return value.""" + if data.trigger is not self: + return + response_future = self.future + if response_future is not None and not response_future.done(): + response_future.set_result(result) + + @staticmethod + def coerce_response(value: Any) -> web.Response | None: + """Convert a webhook function return value to an aiohttp Response.""" + if value is None: + return None + if isinstance(value, web.Response): + return value + # bool is a subclass of int; reject it so True/False don't become 1/0 status codes. + if isinstance(value, int) and not isinstance(value, bool): + return web.Response(status=value) + _LOGGER.warning( + "webhook function returned unsupported type %s; expected int status code or aiohttp.web.Response", + type(value).__name__, + ) + return None + @staticmethod def _add_trigger(trigger: WebhookTriggerDecorator) -> None: webhook_id = trigger.webhook_id - if webhook_id not in WebhookTriggerDecorator.webhook_id2triggers: + existing = WebhookTriggerDecorator.webhook_id2triggers.get(webhook_id) + if ( + trigger.sets_http_response_code + and existing is not None + and any(t.sets_http_response_code for t in existing) + ): + raise ValueError( + f"webhook_id '{webhook_id}' already has a @webhook_trigger with " + f"sets_http_response_code=True; only one is allowed" + ) + + if existing is None: webhook.async_register( trigger.dm.hass, "pyscript", # DOMAIN diff --git a/custom_components/pyscript/stubs/pyscript_builtins.py b/custom_components/pyscript/stubs/pyscript_builtins.py index ea75580..61d3517 100644 --- a/custom_components/pyscript/stubs/pyscript_builtins.py +++ b/custom_components/pyscript/stubs/pyscript_builtins.py @@ -127,6 +127,7 @@ def webhook_trigger( str_expr: str | None = None, local_only: bool = True, methods: set[SUPPORTED_METHODS] | list[SUPPORTED_METHODS] = {"POST", "PUT"}, + sets_http_response_code: bool = False, kwargs: dict | None = None, ) -> Callable[..., Any]: """Trigger when a request is made to a webhook endpoint. @@ -136,6 +137,7 @@ def webhook_trigger( str_expr: Optional expression evaluated against ``trigger_type``, ``webhook_id``, ``request``, and ``payload``. local_only: If False, allow requests from anywhere on the internet. methods: HTTP methods to allow. + sets_http_response_code: If True, the function's return value drives the HTTP response (``int`` status code or ``aiohttp.web.Response``); at most one trigger per ``webhook_id`` may set this. kwargs: Extra keyword arguments merged into each invocation. Trigger kwargs include ``trigger_type="webhook"``, ``webhook_id``, the parsed payload fields, and ``request`` (the underlying ``aiohttp.web.Request``). diff --git a/docs/reference.rst b/docs/reference.rst index 3b7c587..2781215 100644 --- a/docs/reference.rst +++ b/docs/reference.rst @@ -915,6 +915,18 @@ To validate an HMAC signature on incoming requests, declare ``request`` in the f return log.info(f"verified webhook: {payload}") +To control the HTTP response sent back to the webhook caller, opt in by passing ``sets_http_response_code=True``. The flagged function's return value then drives the response: ``None`` produces a ``200 OK``, an ``int`` sends back a response with that status code, and an ``aiohttp.web.Response`` allows full control over the body and headers. Return values from triggers without the flag are ignored. For example: + +.. code:: python + + @webhook_trigger("myid", sets_http_response_code=True) + def webhook_check(payload): + if "token" not in payload: + return 401 + return 204 + +At most one ``@webhook_trigger`` per ``webhook_id`` may set ``sets_http_response_code=True``; declaring more than one is an error at setup time. The webhook handler waits for all decorated function(s) for the ``webhook_id`` to finish before responding, so use ``task.create()`` to fire-and-forget any long-running work. + NOTE: A webhook_id can only be used by either a built-in Home Assistant automation or pyscript, but not both. Trying to use the same webhook_id in both will result in an error. @state_active diff --git a/tests/test_decorator_manager.py b/tests/test_decorator_manager.py index 45c6f89..575b96d 100644 --- a/tests/test_decorator_manager.py +++ b/tests/test_decorator_manager.py @@ -355,7 +355,11 @@ def make_dispatch_data( hass_context: Context | None = None, ) -> DispatchData: """Build DispatchData from test doubles.""" - return DispatchData(func_args, call_ast_ctx=call_ast_ctx, hass_context=hass_context) + return DispatchData( + func_args, + call_ast_ctx=call_ast_ctx, + hass_context=hass_context, + ) def setup_global_context_function_hass(hass: HomeAssistant, config_data: dict | None = None) -> None: @@ -599,6 +603,30 @@ async def test_function_decorator_manager_logs_call_exception(hass): assert str(ast_ctx.logged_exceptions[0]) == "decorated call failed" +@pytest.mark.asyncio +async def test_function_decorator_manager_exception_calls_result_handlers(hass): + """When the decorated function raises, result handlers should be notified with None.""" + DecoratorManager.hass = hass + ast_ctx = DummyAstCtx() + manager = FunctionDecoratorManager(ast_ctx, DummyEvalFuncVar()) + result_handler = make_recording_result_handler() + manager.add(result_handler) + call_ast_ctx = DummyCallAstCtx(exc=RuntimeError("boom")) + + with patch.object(Function, "store_hass_context"): + await call_function_manager( + manager, + make_dispatch_data( + {"arg1": 1}, + call_ast_ctx=call_ast_ctx, + hass_context=Context(id="call-parent"), + ), + ) + + assert result_handler.results == [None] + assert len(ast_ctx.logged_exceptions) == 1 + + def test_decorator_registry_register_requires_name(): """Registry should reject decorators without a declared name.""" diff --git a/tests/test_decorators.py b/tests/test_decorators.py index 12224d4..a214787 100644 --- a/tests/test_decorators.py +++ b/tests/test_decorators.py @@ -3,12 +3,15 @@ from ast import literal_eval import asyncio from datetime import datetime as dt -from unittest.mock import mock_open, patch +from http import HTTPStatus +from unittest.mock import Mock, mock_open, patch +from aiohttp import web import pytest from custom_components.pyscript import trigger from custom_components.pyscript.const import DOMAIN +from custom_components.pyscript.decorators.webhook import WebhookTriggerDecorator from custom_components.pyscript.function import Function from homeassistant.components import webhook from homeassistant.const import EVENT_HOMEASSISTANT_STARTED, EVENT_STATE_CHANGED @@ -256,3 +259,169 @@ def webhook_test(payload, request): await webhook.async_handle_webhook(hass, "test_req_hook", request) assert literal_eval(await wait_until_done(notify_q)) == ["abc123", "POST", {"hello": "world"}] + + +def _post_webhook_request() -> MockRequest: + """Build a MockRequest representing a webhook POST with form data.""" + return MockRequest( + content=b"", + headers={}, + method="POST", + query_string="", + mock_source="test", + remote="127.0.0.1", + ) + + +def test_webhook_coerce_response_none(): + """A None return should fall through to the HA default response.""" + assert WebhookTriggerDecorator.coerce_response(None) is None + + +def test_webhook_coerce_response_int(): + """Int returns should produce an aiohttp Response with that status.""" + response = WebhookTriggerDecorator.coerce_response(HTTPStatus.CREATED.value) + assert isinstance(response, web.Response) + assert response.status == HTTPStatus.CREATED + + +def test_webhook_coerce_response_passthrough(): + """An aiohttp Response should be returned unchanged.""" + custom = web.Response(status=HTTPStatus.ACCEPTED, body=b"queued") + assert WebhookTriggerDecorator.coerce_response(custom) is custom + + +def test_webhook_coerce_response_bool_warns(caplog): + """Bool returns should be rejected so True/False don't masquerade as 1/0.""" + assert WebhookTriggerDecorator.coerce_response(True) is None + assert "unsupported type bool" in caplog.text + + +def test_webhook_coerce_response_unsupported_warns(caplog): + """Other return types should warn and fall through.""" + assert WebhookTriggerDecorator.coerce_response("ok") is None + assert "unsupported type str" in caplog.text + + +@pytest.mark.asyncio +async def test_webhook_function_returns_status_code(hass): + """A flagged webhook function returning an int should set the HTTP status.""" + await setup_script( + hass, + None, + dt(2020, 7, 1, 11, 59, 59, 999999), + """ +@webhook_trigger("status_hook", sets_http_response_code=True) +def func_status(payload): + return 201 +""", + ) + hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED) + await hass.async_block_till_done() + response = await webhook.async_handle_webhook(hass, "status_hook", _post_webhook_request()) + await hass.async_block_till_done() + assert response.status == HTTPStatus.CREATED + + +@pytest.mark.asyncio +async def test_webhook_function_default_response(hass): + """A pyscript webhook function returning None should produce a 200 OK.""" + await setup_script( + hass, + None, + dt(2020, 7, 1, 11, 59, 59, 999999), + """ +@webhook_trigger("default_hook") +def func_default(payload): + pass +""", + ) + hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED) + await hass.async_block_till_done() + response = await webhook.async_handle_webhook(hass, "default_hook", _post_webhook_request()) + await hass.async_block_till_done() + assert response.status == HTTPStatus.OK + + +@pytest.mark.asyncio +async def test_webhook_unflagged_return_ignored(hass): + """A return value from a trigger without sets_http_response_code is ignored.""" + await setup_script( + hass, + None, + dt(2020, 7, 1, 11, 59, 59, 999999), + """ +@webhook_trigger("unflagged_hook") +def func_unflagged(payload): + return 418 +""", + ) + hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED) + await hass.async_block_till_done() + response = await webhook.async_handle_webhook(hass, "unflagged_hook", _post_webhook_request()) + await hass.async_block_till_done() + assert response.status == HTTPStatus.OK + + +@pytest.mark.asyncio +async def test_webhook_only_flagged_trigger_controls_response(hass): + """When multiple triggers share an id, only the flagged one drives the response.""" + await setup_script( + hass, + None, + dt(2020, 7, 1, 11, 59, 59, 999999), + """ +@webhook_trigger("multi_hook") +def func_silent(payload): + return 500 + +@webhook_trigger("multi_hook", sets_http_response_code=True) +def func_loud(payload): + return 418 +""", + ) + hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED) + await hass.async_block_till_done() + response = await webhook.async_handle_webhook(hass, "multi_hook", _post_webhook_request()) + await hass.async_block_till_done() + assert response.status == 418 + + +@pytest.mark.asyncio +async def test_webhook_flag_is_per_webhook_id(hass): + """Different webhook_ids may each have their own sets_http_response_code=True trigger.""" + await setup_script( + hass, + None, + dt(2020, 7, 1, 11, 59, 59, 999999), + """ +@webhook_trigger("hook_a", sets_http_response_code=True) +def func_a(payload): + return 201 + +@webhook_trigger("hook_b", sets_http_response_code=True) +def func_b(payload): + return 202 +""", + ) + hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED) + await hass.async_block_till_done() + response_a = await webhook.async_handle_webhook(hass, "hook_a", _post_webhook_request()) + response_b = await webhook.async_handle_webhook(hass, "hook_b", _post_webhook_request()) + await hass.async_block_till_done() + assert response_a.status == HTTPStatus.CREATED + assert response_b.status == HTTPStatus.ACCEPTED + + +def test_webhook_multiple_flagged_triggers_fails_at_setup(): + """A second flagged trigger for the same webhook_id should be rejected at start.""" + first = Mock(webhook_id="dup_hook", sets_http_response_code=True) + second = Mock(webhook_id="dup_hook", sets_http_response_code=True) + + WebhookTriggerDecorator.webhook_id2triggers.pop("dup_hook", None) + WebhookTriggerDecorator.webhook_id2triggers["dup_hook"] = {first} + try: + with pytest.raises(ValueError, match="sets_http_response_code=True"): + WebhookTriggerDecorator._add_trigger(second) # pylint: disable=protected-access + finally: + WebhookTriggerDecorator.webhook_id2triggers.pop("dup_hook", None)