diff --git a/fastapi_startkit/src/fastapi_startkit/broadcasting/__init__.py b/fastapi_startkit/src/fastapi_startkit/broadcasting/__init__.py index e711ce37..a956ef13 100644 --- a/fastapi_startkit/src/fastapi_startkit/broadcasting/__init__.py +++ b/fastapi_startkit/src/fastapi_startkit/broadcasting/__init__.py @@ -1,14 +1,17 @@ +from .auth import ChannelAuthRegistry from .channels import Channel, PrivateChannel, PresenceChannel from .event import BroadcastEvent, ShouldBroadcast +from .helpers import broadcast, channel from .provider import ReverbProvider -from .helpers import broadcast __all__ = [ "Channel", - "PrivateChannel", + "ChannelAuthRegistry", "PresenceChannel", + "PrivateChannel", "BroadcastEvent", - "ShouldBroadcast", "ReverbProvider", + "ShouldBroadcast", "broadcast", + "channel", ] diff --git a/fastapi_startkit/src/fastapi_startkit/broadcasting/auth.py b/fastapi_startkit/src/fastapi_startkit/broadcasting/auth.py new file mode 100644 index 00000000..6c4c6cc5 --- /dev/null +++ b/fastapi_startkit/src/fastapi_startkit/broadcasting/auth.py @@ -0,0 +1,140 @@ +"""Channel authorization registry for the Reverb broadcasting module. + +Usage:: + + from fastapi_startkit.broadcasting import channel + + @channel("orders.{order_id}") + async def authorize_orders_channel(user, order_id: int): + return user.id == order_id # True = authorized, False = denied + +The registry is held on the ``BroadcastManager`` instance that lives in the +container. ``ReverbProvider`` mounts the ``/broadcasting/auth`` endpoint +which calls :meth:`ChannelAuthRegistry.authorize` to verify subscriptions. +""" + +from __future__ import annotations + +import inspect +import re +from collections.abc import Callable +from typing import Any + + +def _pattern_to_regex(pattern: str) -> re.Pattern: + """Convert a ``{wildcard}``-style pattern to a compiled regex. + + Example: ``"orders.{order_id}"`` -> ``^orders\\.(?P[^.]+)$`` + """ + # Escape everything except our placeholder tokens + parts = re.split(r"(\{[^}]+\})", pattern) + regex_parts: list[str] = [] + for part in parts: + if part.startswith("{") and part.endswith("}"): + name = part[1:-1] + regex_parts.append(f"(?P<{name}>[^.]+)") + else: + regex_parts.append(re.escape(part)) + return re.compile("^" + "".join(regex_parts) + "$") + + +class ChannelAuthRegistry: + """Stores and evaluates channel authorization callbacks.""" + + def __init__(self) -> None: + # List of (compiled_pattern, original_pattern, callback) + self._rules: list[tuple[re.Pattern, str, Callable]] = [] + + # ------------------------------------------------------------------ + # Registration + # ------------------------------------------------------------------ + + def register(self, pattern: str, callback: Callable) -> None: + """Register *callback* as the authorizer for channels matching *pattern*.""" + self._rules.append((_pattern_to_regex(pattern), pattern, callback)) + + def channel(self, pattern: str) -> Callable: + """Decorator factory — registers the decorated function for *pattern*. + + Example:: + + @registry.channel("orders.{order_id}") + async def authorize(user, order_id: int): + return user.id == order_id + """ + + def decorator(func: Callable) -> Callable: + self.register(pattern, func) + return func + + return decorator + + # ------------------------------------------------------------------ + # Lookup + # ------------------------------------------------------------------ + + def _find(self, channel_name: str) -> tuple[Callable, dict[str, str]] | None: + """Return ``(callback, wildcard_kwargs)`` for *channel_name*, or ``None``.""" + for compiled, _pattern, callback in self._rules: + m = compiled.match(channel_name) + if m: + return callback, m.groupdict() + return None + + # ------------------------------------------------------------------ + # Authorization + # ------------------------------------------------------------------ + + async def authorize(self, channel_name: str, user: Any) -> bool: + """Evaluate the authorization callback for *channel_name*. + + * Public channels (no ``private-`` / ``presence-`` prefix): always + allowed, no callback needed. + * Private / presence channels without a registered callback: **denied** + by default (fail-safe). + * Private / presence channels with a registered callback: the return + value of the callback determines access. + + Wildcard values extracted from the pattern are cast to the type hints + of the matching parameters before the callback is called. + """ + + is_private = channel_name.startswith("private-") or channel_name.startswith("presence-") + + result = self._find(channel_name) + + # Public channel with no rule → allow + if result is None and not is_private: + return True + + # Private/presence channel with no rule → deny + if result is None: + return False + + callback, raw_kwargs = result + + # Cast wildcard values to the declared type hints of the callback + hints = {} + try: + sig = inspect.signature(callback) + hints = { + k: p.annotation + for k, p in sig.parameters.items() + if p.annotation is not inspect.Parameter.empty and k != "user" + } + except (ValueError, TypeError): + pass + + cast_kwargs: dict[str, Any] = {} + for k, v in raw_kwargs.items(): + try: + cast_kwargs[k] = hints[k](v) if k in hints else v + except (ValueError, TypeError): + cast_kwargs[k] = v + + # Call the callback (sync or async) + result_value = callback(user, **cast_kwargs) + if inspect.isawaitable(result_value): + result_value = await result_value + + return bool(result_value) diff --git a/fastapi_startkit/src/fastapi_startkit/broadcasting/channels.py b/fastapi_startkit/src/fastapi_startkit/broadcasting/channels.py index ca6c35cc..6668667b 100644 --- a/fastapi_startkit/src/fastapi_startkit/broadcasting/channels.py +++ b/fastapi_startkit/src/fastapi_startkit/broadcasting/channels.py @@ -5,6 +5,14 @@ def __init__(self, name: str): def __repr__(self) -> str: return f"{self.__class__.__name__}({self.name!r})" + def __eq__(self, other: object) -> bool: + if not isinstance(other, Channel): + return NotImplemented + return self.name == other.name + + def __hash__(self) -> int: + return hash(self.name) + class PrivateChannel(Channel): def __init__(self, name: str): diff --git a/fastapi_startkit/src/fastapi_startkit/broadcasting/drivers/log_driver.py b/fastapi_startkit/src/fastapi_startkit/broadcasting/drivers/log_driver.py index f2c7dace..2ddfcc82 100644 --- a/fastapi_startkit/src/fastapi_startkit/broadcasting/drivers/log_driver.py +++ b/fastapi_startkit/src/fastapi_startkit/broadcasting/drivers/log_driver.py @@ -9,3 +9,7 @@ async def broadcast(self, event) -> None: logger.info( f"[Broadcast] channel={channel.name} event={event.broadcast_as()} data={event.broadcast_with()}" ) + + async def broadcast_raw(self, channel: str, event_name: str, payload: dict) -> None: + """Log a raw channel/event/payload triple (used by :meth:`BroadcastManager.emit`).""" + logger.info(f"[Broadcast] channel={channel} event={event_name} data={payload}") diff --git a/fastapi_startkit/src/fastapi_startkit/broadcasting/event.py b/fastapi_startkit/src/fastapi_startkit/broadcasting/event.py index ef6c5f0a..57661a63 100644 --- a/fastapi_startkit/src/fastapi_startkit/broadcasting/event.py +++ b/fastapi_startkit/src/fastapi_startkit/broadcasting/event.py @@ -1,4 +1,12 @@ +from __future__ import annotations + from abc import ABC, abstractmethod +from typing import TYPE_CHECKING + +from fastapi_startkit.application import app + +if TYPE_CHECKING: + pass class ShouldBroadcast(ABC): @@ -7,8 +15,58 @@ def broadcast_on(self) -> list: ... class BroadcastEvent(ShouldBroadcast): + """Base class for all broadcastable events. + + Subclasses must implement ``broadcast_on()`` returning a list of Channel + objects. At dispatch time the event is pushed to every channel in that + list using the configured broadcast driver. + + Attributes: + payload: Arbitrary data dict sent to subscribers. Defaults to an + empty dict; subclasses can populate it in ``__init__`` or by + overriding ``broadcast_with()``. + name: Event name seen by subscribers. Defaults to the class name at + emit time. Override to use a custom string. + """ + + payload: dict = {} + name: str | None = None + + # ------------------------------------------------------------------ + # Driver-layer helpers (used by ReverbDriver / LogDriver) + # ------------------------------------------------------------------ + def broadcast_as(self) -> str: - return self.__class__.__name__ + """Return the event name used on the wire. + + Uses ``self.name`` when set; otherwise falls back to the class name. + """ + return self.name if self.name is not None else self.__class__.__name__ def broadcast_with(self) -> dict: + """Return the payload dict. + + Returns ``self.payload`` when it is non-empty; otherwise serialises + all instance attributes (backward-compatible behaviour). + """ + if self.payload: + return dict(self.payload) return {k: v for k, v in self.__dict__.items()} + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + async def emit(self) -> None: + """Dispatch this event via the BroadcastManager. + + Resolves the ``"broadcasting"`` service directly from the Application + container — no facade in the call chain. + + Equivalent to:: + + from fastapi_startkit.application import app + await app().make("broadcasting").dispatch(self) + """ + manager = app().make("broadcasting") + await manager.dispatch(self) diff --git a/fastapi_startkit/src/fastapi_startkit/broadcasting/helpers.py b/fastapi_startkit/src/fastapi_startkit/broadcasting/helpers.py index a8879fb9..d6e95e70 100644 --- a/fastapi_startkit/src/fastapi_startkit/broadcasting/helpers.py +++ b/fastapi_startkit/src/fastapi_startkit/broadcasting/helpers.py @@ -1,3 +1,5 @@ +from collections.abc import Callable + from ..application import app @@ -5,3 +7,33 @@ async def broadcast(event) -> None: """Broadcast an event using the default driver.""" manager = app().make("broadcasting") await manager.event(event) + + +def channel(pattern: str) -> Callable: + """Register a channel authorization callback. + + A decorator factory that registers the decorated function with the + ``ChannelAuthRegistry`` on the booted ``BroadcastManager``. + + Usage:: + + from fastapi_startkit.broadcasting import channel + + @channel("orders.{order_id}") + async def authorize_orders(user, order_id: int) -> bool: + return user is not None and user.id == order_id + + @channel("private-notifications") + async def authorize_notifications(user) -> bool: + return user is not None + + Pattern supports ``{wildcard}`` placeholders; wildcard values are cast to + the type-hints of the matching callback parameters before calling. + """ + + def decorator(func: Callable) -> Callable: + manager = app().make("broadcasting") + manager.channel_registry.register(pattern, func) + return func + + return decorator diff --git a/fastapi_startkit/src/fastapi_startkit/broadcasting/manager.py b/fastapi_startkit/src/fastapi_startkit/broadcasting/manager.py index dad67c56..196f7664 100644 --- a/fastapi_startkit/src/fastapi_startkit/broadcasting/manager.py +++ b/fastapi_startkit/src/fastapi_startkit/broadcasting/manager.py @@ -1,9 +1,34 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from .auth import ChannelAuthRegistry + +if TYPE_CHECKING: + from .event import BroadcastEvent + + class BroadcastManager: - def __init__(self, config, server=None): + """Central broadcast manager. + + Holds the active driver, the Reverb server instance (if any), and the + :class:`~.auth.ChannelAuthRegistry` used by the ``/broadcasting/auth`` + endpoint. + + This is the service that the :class:`~fastapi_startkit.facades.Broadcast` + facade proxies to (container key: ``"broadcasting"``). + """ + + def __init__(self, config: dict, server: Any = None) -> None: self._config = config self._server = server + self.channel_registry = ChannelAuthRegistry() + + # ------------------------------------------------------------------ + # Driver resolution + # ------------------------------------------------------------------ - def driver(self, name=None): + def driver(self, name: str | None = None): name = name or self._config.get("default", "log") if name == "log": from .drivers.log_driver import LogDriver @@ -15,5 +40,55 @@ def driver(self, name=None): return ReverbDriver(self._server) raise ValueError(f"Unknown broadcast driver: {name}") - async def event(self, broadcast_event) -> None: + # ------------------------------------------------------------------ + # Dispatch helpers + # ------------------------------------------------------------------ + + async def event(self, broadcast_event: "BroadcastEvent") -> None: + """Broadcast *broadcast_event* using the default driver. + + .. deprecated:: + Prefer :meth:`dispatch` for new code. + """ await self.driver().broadcast(broadcast_event) + + async def dispatch(self, broadcast_event: "BroadcastEvent") -> None: + """Primary dispatch path. + + Sends *broadcast_event* to all channels returned by its + ``broadcast_on()`` method via the configured driver. + """ + await self.driver().broadcast(broadcast_event) + + async def emit(self, channel: str, event_name: str, payload: dict) -> None: + """Escape-hatch for direct / dynamic broadcasts. + + Sends *payload* as *event_name* to *channel* without needing a + ``BroadcastEvent`` subclass. + """ + driver_name = self._config.get("default", "log") + if driver_name == "reverb" and self._server is not None: + await self._server.broadcast_to_channel(channel, event_name, payload) + else: + # Log driver fallback + from .drivers.log_driver import LogDriver + + driver = LogDriver() + await driver.broadcast_raw(channel, event_name, payload) + + # ------------------------------------------------------------------ + # Channel authorization registry proxy + # ------------------------------------------------------------------ + + def channel(self, pattern: str): + """Decorator factory — register a channel authorization callback. + + Example:: + + Broadcast.channel("orders.{order_id}") + async def authorize(user, order_id: int): + return user.id == order_id + + Delegates to :attr:`channel_registry`. + """ + return self.channel_registry.channel(pattern) diff --git a/fastapi_startkit/src/fastapi_startkit/broadcasting/provider.py b/fastapi_startkit/src/fastapi_startkit/broadcasting/provider.py index 52bbe8dc..4ae4096e 100644 --- a/fastapi_startkit/src/fastapi_startkit/broadcasting/provider.py +++ b/fastapi_startkit/src/fastapi_startkit/broadcasting/provider.py @@ -1,10 +1,35 @@ +from __future__ import annotations + +import importlib +import importlib.util +import sys +from pathlib import Path + from ..providers import Provider -from .manager import BroadcastManager from .config import BroadcastingConfig +from .manager import BroadcastManager from .reverb.server import ReverbServer class ReverbProvider(Provider): + """Service provider that wires up the Reverb broadcasting stack. + + Registering ``ReverbProvider`` in the providers list is the **only** thing + a user needs to do to get broadcasting working. The provider: + + * Binds a :class:`~.manager.BroadcastManager` and a + :class:`~.reverb.server.ReverbServer` into the container. + * Mounts the Pusher-protocol WebSocket endpoint on the FastAPI app. + * Mounts a ``/broadcasting/auth`` HTTP endpoint for private/presence + channel authorisation (Laravel Echo handshake). + * Auto-loads ``routes/channels.py`` from the application base path so + ``@Broadcast.channel()`` decorators in that file are registered before + any subscription is attempted. + + Configuration is read from environment variables — see + :class:`~.config.BroadcastingConfig` for the full list. + """ + provider_key = "broadcasting" def register(self) -> None: @@ -18,4 +43,126 @@ def register(self) -> None: self.app.bind("reverb.server", server) def boot(self) -> None: - pass + import os + + # Publish the channels.py stub so developers can scaffold it + stub_src = os.path.abspath(os.path.join(os.path.dirname(__file__), "stubs/channels.py")) + self.publishes({stub_src: "routes/channels.py"}) + + self._load_channels_file() + self._mount_websocket() + self._mount_auth_endpoint() + + # ------------------------------------------------------------------ + # Private helpers + # ------------------------------------------------------------------ + + def _load_channels_file(self) -> None: + """Auto-load ``routes/channels.py`` from the application base path. + + This is the file developers fill in with ``@channel()`` decorators + (imported from ``fastapi_startkit.broadcasting``). Importing it + registers the callbacks with the ``ChannelAuthRegistry`` on the + ``BroadcastManager``. + + The file is optional — its absence is silently ignored. + """ + channels_path = Path(self.app.base_path) / "routes" / "channels.py" + if not channels_path.exists(): + return + + module_name = "app.routes.channels" + spec = importlib.util.spec_from_file_location(module_name, str(channels_path)) + if spec is None or spec.loader is None: + return + + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore[union-attr] + + def _mount_websocket(self) -> None: + """Mount the Reverb WebSocket endpoint on the FastAPI application. + + The endpoint path is ``/app/{app_key}``, matching the Pusher protocol + expected by Laravel Echo. The ``app_key`` is taken from the + ``reverb.connections.reverb.key`` config value (default: ``"local"``). + """ + try: + fastapi_app = self.app.fastapi + except RuntimeError: + # FastAPI not installed or app not yet booted — skip + return + + config = self.app.make("config").get("broadcasting") or {} + connections = config.get("connections", {}) if isinstance(config, dict) else {} + reverb_conn = connections.get("reverb", {}) if isinstance(connections, dict) else {} + app_key = reverb_conn.get("key", "local") if isinstance(reverb_conn, dict) else "local" + + server: ReverbServer = self.app.make("reverb.server") + + from starlette.routing import WebSocketRoute + from starlette.applications import Starlette + + async def ws_endpoint(websocket): + await server.handle(websocket) + + ws_app = Starlette(routes=[WebSocketRoute(f"/app/{app_key}", ws_endpoint)]) + fastapi_app.mount("", ws_app) + + def _mount_auth_endpoint(self) -> None: + """Mount the ``/broadcasting/auth`` HTTP endpoint. + + The endpoint handles the Laravel Echo authentication handshake for + private and presence channels. It: + + 1. Reads ``channel_name`` from the POST body (form or JSON). + 2. Resolves the current user via the container's ``auth`` binding (if + available), falling back to ``request.state.user``. + 3. Calls :meth:`~.auth.ChannelAuthRegistry.authorize` with the channel + name and the resolved user. + 4. Returns **200** on success or **403** on denial. + """ + try: + fastapi_app = self.app.fastapi + except RuntimeError: + return + + from fastapi import Request + from fastapi.responses import JSONResponse + + application = self.app + + @fastapi_app.post("/broadcasting/auth") + async def broadcasting_auth(request: Request): + # Resolve channel name from form data or JSON body + content_type = request.headers.get("content-type", "") + if "application/json" in content_type: + body = await request.json() + channel_name = body.get("channel_name", "") + else: + form = await request.form() + channel_name = form.get("channel_name", "") + + # Resolve authenticated user. + # Prefer request.state.user (set by auth middleware), then fall + # back to the container's "auth" service. + user = getattr(request.state, "user", None) + if user is None: + try: + auth_service = application.make("auth") + user_attr = getattr(auth_service, "user", None) + if callable(user_attr): + user = user_attr() + else: + user = user_attr + except Exception: + user = None + + # Authorise + manager: BroadcastManager = application.make("broadcasting") + allowed = await manager.channel_registry.authorize(channel_name, user) + + if not allowed: + return JSONResponse({"error": "Forbidden"}, status_code=403) + + return JSONResponse({"auth": True, "channel_data": {}}) diff --git a/fastapi_startkit/src/fastapi_startkit/broadcasting/stubs/__init__.py b/fastapi_startkit/src/fastapi_startkit/broadcasting/stubs/__init__.py new file mode 100644 index 00000000..3f7f7181 --- /dev/null +++ b/fastapi_startkit/src/fastapi_startkit/broadcasting/stubs/__init__.py @@ -0,0 +1 @@ +# Stub templates for the Reverb broadcasting module. diff --git a/fastapi_startkit/src/fastapi_startkit/broadcasting/stubs/channels.py b/fastapi_startkit/src/fastapi_startkit/broadcasting/stubs/channels.py new file mode 100644 index 00000000..f74d3150 --- /dev/null +++ b/fastapi_startkit/src/fastapi_startkit/broadcasting/stubs/channels.py @@ -0,0 +1,35 @@ +"""Channel authorization callbacks. + +Register authorization callbacks for private and presence channels here +using the ``@channel()`` decorator. + +Example:: + + from fastapi_startkit.broadcasting import channel + + @channel("orders.{order_id}") + async def authorize_orders_channel(user, order_id: int) -> bool: + # Return True to grant access, False to deny. + # `user` is the currently-authenticated user (resolved from the + # container's `auth` service or request.state.user). + return user is not None and user.id == order_id + + @channel("private-notifications") + async def authorize_notifications(user) -> bool: + return user is not None + +Supported channel types: + - ``Channel("name")`` — public, no auth required + - ``PrivateChannel("name")`` — auth checked via @channel + - ``PresenceChannel("name")`` — auth checked + member tracking (v2) + +This file is auto-loaded by ReverbProvider when it boots. +""" + +from fastapi_startkit.broadcasting import channel # noqa: F401 + +# Define your channel authorization callbacks below: +# +# @channel("private-{channel}") +# async def authorize_private_channel(user, channel: str) -> bool: +# return user is not None diff --git a/fastapi_startkit/src/fastapi_startkit/facades/Broadcast.py b/fastapi_startkit/src/fastapi_startkit/facades/Broadcast.py deleted file mode 100644 index 6cfaed34..00000000 --- a/fastapi_startkit/src/fastapi_startkit/facades/Broadcast.py +++ /dev/null @@ -1,5 +0,0 @@ -from .Facade import Facade - - -class Broadcast(metaclass=Facade): - key = "broadcast" diff --git a/fastapi_startkit/src/fastapi_startkit/facades/Broadcast.pyi b/fastapi_startkit/src/fastapi_startkit/facades/Broadcast.pyi deleted file mode 100644 index 7e2420fd..00000000 --- a/fastapi_startkit/src/fastapi_startkit/facades/Broadcast.pyi +++ /dev/null @@ -1,15 +0,0 @@ -from fastapi_startkit.broadcasting.channels import Channel -from fastapi_startkit.broadcasting.event import BroadcastEvent - -class Broadcast: - """Facade for broadcasting events over WebSocket channels.""" - - @staticmethod - async def event(event: BroadcastEvent) -> None: - """Broadcast an event to all its channels.""" - ... - - @staticmethod - def channel(name: str) -> Channel: - """Return a Channel instance for the given name.""" - ... diff --git a/fastapi_startkit/src/fastapi_startkit/facades/__init__.py b/fastapi_startkit/src/fastapi_startkit/facades/__init__.py index a8e5d064..6376090a 100644 --- a/fastapi_startkit/src/fastapi_startkit/facades/__init__.py +++ b/fastapi_startkit/src/fastapi_startkit/facades/__init__.py @@ -14,4 +14,3 @@ from .Queue import Queue from .Cache import Cache from .RateLimiter import RateLimiter -from .Broadcast import Broadcast diff --git a/fastapi_startkit/tests/broadcasting/test_channel_auth.py b/fastapi_startkit/tests/broadcasting/test_channel_auth.py new file mode 100644 index 00000000..aa94c474 --- /dev/null +++ b/fastapi_startkit/tests/broadcasting/test_channel_auth.py @@ -0,0 +1,176 @@ +"""Tests for the ChannelAuthRegistry and @Broadcast.channel() decorator (Task #148).""" + +import pytest + +from fastapi_startkit.broadcasting.auth import ChannelAuthRegistry + + +# --------------------------------------------------------------------------- +# Pattern-matching helpers (smoke tests) +# --------------------------------------------------------------------------- + + +def test_exact_channel_match(): + registry = ChannelAuthRegistry() + + @registry.channel("orders.42") + async def authorize(user): + return True + + # Should find the pattern and match + result = registry._find("orders.42") + assert result is not None + + +def test_wildcard_channel_match(): + registry = ChannelAuthRegistry() + + @registry.channel("orders.{order_id}") + async def authorize(user, order_id: int): + return True + + result = registry._find("orders.99") + assert result is not None + callback, kwargs = result + assert kwargs == {"order_id": "99"} + + +def test_no_match_returns_none(): + registry = ChannelAuthRegistry() + assert registry._find("unknown.channel") is None + + +# --------------------------------------------------------------------------- +# Authorization — public channels +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_public_channel_no_rule_allowed(): + registry = ChannelAuthRegistry() + # No callback registered; public channel = allow + assert await registry.authorize("orders.1", user=None) is True + + +@pytest.mark.asyncio +async def test_public_channel_with_allowing_rule(): + registry = ChannelAuthRegistry() + + @registry.channel("orders.{id}") + async def auth(user, id: str): + return True + + assert await registry.authorize("orders.1", user=None) is True + + +# --------------------------------------------------------------------------- +# Authorization — private channels +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_private_channel_no_rule_denied(): + registry = ChannelAuthRegistry() + assert await registry.authorize("private-orders.1", user=None) is False + + +@pytest.mark.asyncio +async def test_private_channel_with_allowing_callback(): + registry = ChannelAuthRegistry() + + @registry.channel("private-orders.{order_id}") + async def auth(user, order_id: int): + return user is not None + + class FakeUser: + pass + + assert await registry.authorize("private-orders.1", user=FakeUser()) is True + + +@pytest.mark.asyncio +async def test_private_channel_with_denying_callback(): + registry = ChannelAuthRegistry() + + @registry.channel("private-orders.{order_id}") + async def auth(user, order_id: int): + return False + + assert await registry.authorize("private-orders.1", user=object()) is False + + +# --------------------------------------------------------------------------- +# Authorization — presence channels +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_presence_channel_no_rule_denied(): + registry = ChannelAuthRegistry() + assert await registry.authorize("presence-room.42", user=None) is False + + +@pytest.mark.asyncio +async def test_presence_channel_with_allowing_callback(): + registry = ChannelAuthRegistry() + + @registry.channel("presence-room.{room_id}") + async def auth(user, room_id: str): + return True + + assert await registry.authorize("presence-room.42", user=object()) is True + + +# --------------------------------------------------------------------------- +# Wildcard type casting +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_wildcard_cast_to_int(): + registry = ChannelAuthRegistry() + received = {} + + @registry.channel("private-orders.{order_id}") + async def auth(user, order_id: int): + received["order_id"] = order_id + return True + + await registry.authorize("private-orders.123", user=object()) + assert received["order_id"] == 123 + assert isinstance(received["order_id"], int) + + +# --------------------------------------------------------------------------- +# Sync callbacks +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_sync_callback_supported(): + registry = ChannelAuthRegistry() + + @registry.channel("private-sync.{id}") + def auth(user, id: str): + return True + + assert await registry.authorize("private-sync.1", user=object()) is True + + +# --------------------------------------------------------------------------- +# BroadcastManager.channel() proxies to registry +# --------------------------------------------------------------------------- + + +def test_manager_channel_decorator_registers_callback(): + from fastapi_startkit.broadcasting.manager import BroadcastManager + + manager = BroadcastManager({"default": "log"}) + + @manager.channel("private-test.{id}") + async def auth(user, id: str): + return True + + # The callback should be findable in the registry + result = manager.channel_registry._find("private-test.99") + assert result is not None diff --git a/fastapi_startkit/tests/broadcasting/test_channels.py b/fastapi_startkit/tests/broadcasting/test_channels.py index 8c1aa4fd..440868c4 100644 --- a/fastapi_startkit/tests/broadcasting/test_channels.py +++ b/fastapi_startkit/tests/broadcasting/test_channels.py @@ -34,3 +34,46 @@ def test_presence_channel_is_channel(): def test_channel_name_unchanged_for_base(): ch = Channel("already-prefixed") assert ch.name == "already-prefixed" + + +# --------------------------------------------------------------------------- +# __eq__ and __hash__ +# --------------------------------------------------------------------------- + + +def test_channel_equality_same_name(): + assert Channel("orders.1") == Channel("orders.1") + + +def test_channel_inequality_different_name(): + assert Channel("orders.1") != Channel("orders.2") + + +def test_channel_equality_not_implemented_for_non_channel(): + ch = Channel("orders.1") + assert ch.__eq__("orders.1") is NotImplemented + + +def test_channel_hashable(): + ch = Channel("orders.1") + assert isinstance(hash(ch), int) + + +def test_channel_usable_in_set(): + channels = {Channel("orders.1"), Channel("orders.1"), Channel("orders.2")} + assert len(channels) == 2 + + +def test_channel_usable_as_dict_key(): + d = {Channel("orders.1"): "value"} + assert d[Channel("orders.1")] == "value" + + +def test_private_channel_equality(): + assert PrivateChannel("orders") == PrivateChannel("orders") + + +def test_private_and_base_channel_not_equal(): + # PrivateChannel("orders").name == "private-orders" + # Channel("orders").name == "orders" + assert PrivateChannel("orders") != Channel("orders") diff --git a/fastapi_startkit/tests/broadcasting/test_event_emit.py b/fastapi_startkit/tests/broadcasting/test_event_emit.py new file mode 100644 index 00000000..884fe841 --- /dev/null +++ b/fastapi_startkit/tests/broadcasting/test_event_emit.py @@ -0,0 +1,103 @@ +"""Tests for the new BroadcastEvent.payload / .name / .emit() API (Task #147).""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from fastapi_startkit.broadcasting.channels import Channel +from fastapi_startkit.broadcasting.event import BroadcastEvent + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +class SimpleEvent(BroadcastEvent): + def broadcast_on(self): + return [Channel("test")] + + +class EventWithPayload(BroadcastEvent): + payload = {"order_id": 42} + + def broadcast_on(self): + return [Channel("orders")] + + +class EventWithName(BroadcastEvent): + name = "custom.event" + + def broadcast_on(self): + return [Channel("test")] + + +# --------------------------------------------------------------------------- +# payload attribute +# --------------------------------------------------------------------------- + + +def test_broadcast_event_default_payload_is_empty_dict(): + event = SimpleEvent() + assert event.payload == {} + + +def test_broadcast_event_payload_returned_by_broadcast_with(): + event = EventWithPayload() + assert event.broadcast_with() == {"order_id": 42} + + +def test_broadcast_event_broadcast_with_falls_back_to_instance_attrs_when_payload_empty(): + class EventWithAttr(BroadcastEvent): + def __init__(self): + self.foo = "bar" + + def broadcast_on(self): + return [Channel("x")] + + event = EventWithAttr() + # payload is empty (default), so instance attrs are returned + assert event.broadcast_with() == {"foo": "bar"} + + +# --------------------------------------------------------------------------- +# name attribute +# --------------------------------------------------------------------------- + + +def test_broadcast_event_default_name_is_none(): + event = SimpleEvent() + assert event.name is None + + +def test_broadcast_as_uses_name_when_set(): + event = EventWithName() + assert event.broadcast_as() == "custom.event" + + +def test_broadcast_as_falls_back_to_class_name_when_name_is_none(): + event = SimpleEvent() + assert event.broadcast_as() == "SimpleEvent" + + +# --------------------------------------------------------------------------- +# emit() +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_emit_calls_broadcast_dispatch(): + """emit() must be async and delegate to manager.dispatch(self) directly.""" + event = SimpleEvent() + + mock_manager = MagicMock() + mock_manager.dispatch = AsyncMock() + + mock_app = MagicMock() + mock_app.return_value.make.return_value = mock_manager + + # Patch the app() factory used inside emit() + with patch("fastapi_startkit.broadcasting.event.app", mock_app): + await event.emit() + + mock_app.return_value.make.assert_called_once_with("broadcasting") + mock_manager.dispatch.assert_called_once_with(event)