diff --git a/fastapi_startkit/src/fastapi_startkit/reverb/__init__.py b/fastapi_startkit/src/fastapi_startkit/reverb/__init__.py new file mode 100644 index 00000000..998e349c --- /dev/null +++ b/fastapi_startkit/src/fastapi_startkit/reverb/__init__.py @@ -0,0 +1,37 @@ +"""Reverb broadcasting module. + +Provides a Laravel-style broadcasting API for FastAPI applications: + +- :class:`.Channel` / :class:`.PrivateChannel` / :class:`.PresenceChannel` — + channel type declarations. +- :class:`.BroadcastEvent` — base class for broadcastable events. +- :class:`.ChannelAuthRegistry` — pattern-based authorization registry. +- :class:`.Broadcaster` — core dispatcher (bound as ``"broadcast"`` in the + service container). +- :class:`.ReverbProvider` — service provider that auto-wires everything. + +Typical usage +------------- +1. Register ``ReverbProvider`` in your application providers. +2. Create ``routes/channels.py`` with ``@Broadcast.channel(...)`` callbacks. +3. Create event classes that extend ``BroadcastEvent`` and implement + ``broadcast_on()``. +4. Call ``await event.emit()`` or ``await Broadcast.dispatch(event)`` to + send events to connected clients. +""" + +from .broadcaster import Broadcaster +from .channels import Channel, PresenceChannel, PrivateChannel +from .event import BroadcastEvent +from .provider import ReverbProvider +from .registry import ChannelAuthRegistry + +__all__ = [ + "Channel", + "PrivateChannel", + "PresenceChannel", + "BroadcastEvent", + "ChannelAuthRegistry", + "Broadcaster", + "ReverbProvider", +] diff --git a/fastapi_startkit/src/fastapi_startkit/reverb/broadcaster.py b/fastapi_startkit/src/fastapi_startkit/reverb/broadcaster.py new file mode 100644 index 00000000..08ae570c --- /dev/null +++ b/fastapi_startkit/src/fastapi_startkit/reverb/broadcaster.py @@ -0,0 +1,120 @@ +"""Core Broadcaster — dispatches events to the Reverb WebSocket server. + +The ``Broadcaster`` is bound into the container under both ``"broadcast"`` and +``"reverb"`` by :class:`~fastapi_startkit.reverb.provider.ReverbProvider`. +Access it via the ``Broadcast`` facade or resolve it directly from the +container. + +Typical usage:: + + # Via facade + await Broadcast.dispatch(OrderShipped(order_id=42)) + await Broadcast.emit("orders.42", "OrderShipped", {"order_id": 42}) + + # Via decorator + @Broadcast.channel("orders.{order_id}") + async def authorize_orders(user, order_id: int) -> bool: + return user.id == order_id +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Callable + +if TYPE_CHECKING: + from .event import BroadcastEvent + from .registry import ChannelAuthRegistry + from ..broadcasting.reverb.server import ReverbServer + + +class Broadcaster: + """Dispatches broadcast events and manages channel authorization. + + Args: + server: The Reverb WebSocket server that delivers messages to + connected clients. + registry: The :class:`~fastapi_startkit.reverb.registry.ChannelAuthRegistry` + that holds ``@Broadcast.channel`` callbacks. + config: Raw broadcasting config dict (from ``BroadcastingConfig``). + """ + + def __init__( + self, + server: "ReverbServer | None" = None, + registry: "ChannelAuthRegistry | None" = None, + config: dict | None = None, + ) -> None: + self._server = server + self._registry = registry + self._config = config or {} + + # ------------------------------------------------------------------ + # Primary dispatch path + # ------------------------------------------------------------------ + + async def dispatch(self, event: "BroadcastEvent") -> None: + """Broadcast a :class:`~fastapi_startkit.reverb.event.BroadcastEvent` + to every channel returned by its ``broadcast_on()`` method. + + Args: + event: A ``BroadcastEvent`` instance. The event name defaults to + ``event.__class__.__name__`` when ``event.name`` is ``None``. + """ + if self._server is None: + return + + channels = event.broadcast_on() + event_name = event.name if event.name is not None else event.__class__.__name__ + payload = event.payload if isinstance(event.payload, dict) else {} + + for channel in channels: + await self._server.broadcast_to_channel(channel.name, event_name, payload) + + # ------------------------------------------------------------------ + # Escape hatch + # ------------------------------------------------------------------ + + async def emit(self, channel: str, event_name: str, payload: dict) -> None: + """Broadcast a raw event without wrapping it in a ``BroadcastEvent``. + + Useful for quick one-off broadcasts or dynamic channel names that + don't warrant a dedicated event class. + + Args: + channel: Full channel name (e.g. ``"orders.42"``). + event_name: Event name sent to subscribers. + payload: Arbitrary JSON-serializable dict. + """ + if self._server is None: + return + await self._server.broadcast_to_channel(channel, event_name, payload) + + # ------------------------------------------------------------------ + # Channel authorization decorator + # ------------------------------------------------------------------ + + def channel(self, pattern: str) -> Callable: + """Register a channel authorization callback. + + Use as a decorator in ``routes/channels.py``:: + + from fastapi_startkit.facades.Broadcast import Broadcast + + @Broadcast.channel("orders.{order_id}") + async def authorize_orders(user, order_id: int) -> bool: + return user.id == order_id + + Args: + pattern: Channel pattern with ``{name}`` placeholders. + + Returns: + A decorator that registers the wrapped callable and returns it + unchanged so it remains importable. + """ + + def decorator(callback: Callable) -> Callable: + if self._registry is not None: + self._registry.register(pattern, callback) + return callback + + return decorator diff --git a/fastapi_startkit/src/fastapi_startkit/reverb/channels.py b/fastapi_startkit/src/fastapi_startkit/reverb/channels.py new file mode 100644 index 00000000..8ee986b2 --- /dev/null +++ b/fastapi_startkit/src/fastapi_startkit/reverb/channels.py @@ -0,0 +1,55 @@ +"""Channel types for Reverb broadcasting. + +Three channel types control authorization behaviour: + +- ``Channel`` — public, no authorization check required. +- ``PrivateChannel`` — authorization is checked via a ``@Broadcast.channel`` + callback before a subscription is accepted. +- ``PresenceChannel``— authorization is checked *and* member tracking is + available (tracking is a v2 concern, but the class + must exist for the API to be forward-compatible). +""" + +from __future__ import annotations + + +class Channel: + """Public channel — subscriptions accepted without any auth check.""" + + def __init__(self, name: str) -> None: + self.name = name + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.name!r})" + + def __eq__(self, other: object) -> bool: + return isinstance(other, Channel) and self.name == other.name + + def __hash__(self) -> int: + return hash((self.__class__.__name__, self.name)) + + +class PrivateChannel(Channel): + """Private channel — authorization checked via ``@Broadcast.channel`` + callback before any subscription is accepted. + + The channel name is automatically prefixed with ``private-`` to match + the Pusher/Laravel Echo protocol convention. + """ + + def __init__(self, name: str) -> None: + self._raw_name = name + super().__init__(f"private-{name}") + + +class PresenceChannel(Channel): + """Presence channel — authorization checked, member tracking available. + + The channel name is automatically prefixed with ``presence-`` to match + the Pusher/Laravel Echo protocol convention. Full member-tracking is a + v2 feature; the class exists now so the public API is stable. + """ + + def __init__(self, name: str) -> None: + self._raw_name = name + super().__init__(f"presence-{name}") diff --git a/fastapi_startkit/src/fastapi_startkit/reverb/event.py b/fastapi_startkit/src/fastapi_startkit/reverb/event.py new file mode 100644 index 00000000..7a9efa1d --- /dev/null +++ b/fastapi_startkit/src/fastapi_startkit/reverb/event.py @@ -0,0 +1,60 @@ +"""BroadcastEvent base class. + +Subclass ``BroadcastEvent`` to create broadcastable events:: + + class OrderShipped(BroadcastEvent): + def __init__(self, order_id: int) -> None: + self.payload = {"order_id": order_id} + + def broadcast_on(self): + return [PrivateChannel(f"orders.{self.order_id}")] + + # Broadcast synchronously from a FastAPI endpoint: + await OrderShipped(order_id=42).emit() + + # Or dispatch via the facade: + await Broadcast.dispatch(OrderShipped(order_id=42)) +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Union + +from .channels import Channel, PresenceChannel, PrivateChannel + + +class BroadcastEvent(ABC): + """Abstract base for all broadcastable events. + + Subclasses **must** implement :meth:`broadcast_on`. ``payload`` and + ``name`` are intentionally class-level defaults so users can override + them either as class attributes *or* in ``__init__``. + """ + + #: Data payload forwarded to subscribers. Override per-instance in + #: ``__init__`` or as a class attribute. + payload: dict = {} + + #: Event name sent over the wire. Defaults to the class name when + #: ``None`` so that renaming the Python class also renames the event. + name: str | None = None + + @abstractmethod + def broadcast_on(self) -> list[Union[Channel, PrivateChannel, PresenceChannel]]: + """Return the channels this event should be broadcast on. + + Must be overridden by every concrete subclass. + """ + ... + + async def emit(self) -> None: + """Convenience shortcut — delegates to ``Broadcast.dispatch(self)``. + + Because ``dispatch`` is a coroutine, ``emit`` must be awaited:: + + await OrderShipped(order_id=42).emit() + """ + from fastapi_startkit.facades.Broadcast import Broadcast + + await Broadcast.dispatch(self) diff --git a/fastapi_startkit/src/fastapi_startkit/reverb/provider.py b/fastapi_startkit/src/fastapi_startkit/reverb/provider.py new file mode 100644 index 00000000..762437d2 --- /dev/null +++ b/fastapi_startkit/src/fastapi_startkit/reverb/provider.py @@ -0,0 +1,165 @@ +"""ReverbProvider — wires Reverb broadcasting into the FastAPI application. + +Registering ``ReverbProvider`` in the application's providers list is the +*only* thing needed to get broadcasting working:: + + from fastapi_startkit.reverb import ReverbProvider + + Application( + base_path=BASE_DIR, + providers=[ + ..., + ReverbProvider, + ], + ) + +What the provider does +---------------------- +``register()``: + * Instantiates :class:`~fastapi_startkit.broadcasting.reverb.server.ReverbServer`, + :class:`~fastapi_startkit.reverb.registry.ChannelAuthRegistry`, and + :class:`~fastapi_startkit.reverb.broadcaster.Broadcaster`. + * Binds them into the container under ``"reverb"``, ``"broadcast"``, + ``"reverb.server"``, and ``"reverb.registry"``. + * Merges Reverb configuration from environment variables. + +``boot()``: + * Auto-loads ``routes/channels.py`` (silent skip when absent). + * Mounts the WebSocket endpoint at the path configured by ``REVERB_PATH`` + (default ``/__reverb``). + * Mounts ``POST /broadcasting/auth`` — the Laravel Echo auth handshake. + Returns ``200`` with a signed auth token or ``403`` when the registry + denies the subscription. +""" + +from __future__ import annotations + +import importlib.util +import os + +from ..broadcasting.config import BroadcastingConfig +from ..broadcasting.reverb.server import ReverbServer +from ..providers import Provider +from .broadcaster import Broadcaster +from .registry import ChannelAuthRegistry + + +class ReverbProvider(Provider): + """Service provider that auto-wires the full Reverb broadcasting stack.""" + + provider_key = "reverb" + + # ------------------------------------------------------------------ + # register() — bind services before any boot() runs + # ------------------------------------------------------------------ + + def register(self) -> None: + config_data = self.resolve_config(BroadcastingConfig) + + server = ReverbServer() + registry = ChannelAuthRegistry() + broadcaster = Broadcaster(server=server, registry=registry, config=config_data) + + # Bind under two keys so both ``Broadcast`` facade (key="broadcast") + # and direct ``app.make("reverb")`` work. + self.app.bind("reverb", broadcaster) + self.app.bind("broadcast", broadcaster) + self.app.bind("reverb.server", server) + self.app.bind("reverb.registry", registry) + + # ------------------------------------------------------------------ + # boot() — mount routes after all providers have registered + # ------------------------------------------------------------------ + + def boot(self) -> None: + self._load_channels_file() + self._mount_routes() + + # ------------------------------------------------------------------ + # Private helpers + # ------------------------------------------------------------------ + + def _load_channels_file(self) -> None: + """Import ``routes/channels.py`` so that ``@Broadcast.channel`` + decorators inside it are executed and callbacks registered. + + Silently skips when the file does not exist or raises an exception + on import (avoids crashing apps that haven't created the file yet). + """ + channels_path = self.app.base_path / "routes" / "channels.py" + if not channels_path.exists(): + return + + spec = importlib.util.spec_from_file_location("routes.channels", channels_path) + if spec is None or spec.loader is None: + return + + module = importlib.util.module_from_spec(spec) + try: + spec.loader.exec_module(module) # type: ignore[union-attr] + except Exception: + # Silent skip — don't crash the app for a bad channels file + pass + + def _mount_routes(self) -> None: + """Mount the WebSocket endpoint and ``/broadcasting/auth`` on the + FastAPI application. + + Gracefully skips when the FastAPI instance is not yet available + (e.g. in pure-CLI / testing contexts that don't boot FastAPI). + """ + try: + fastapi_app = self.app.fastapi + except Exception: + return + + server: ReverbServer = self.app.make("reverb.server") + registry: ChannelAuthRegistry = self.app.make("reverb.registry") + + reverb_path: str = os.environ.get("REVERB_PATH", "/__reverb") + app_key: str = os.environ.get("REVERB_APP_KEY", "local") + + # ---- WebSocket endpoint ------------------------------------------ + fastapi_app.mount(reverb_path, server.as_starlette_app(app_key)) + + # ---- /broadcasting/auth ------------------------------------------ + from fastapi import Request + from fastapi.responses import JSONResponse + + @fastapi_app.post("/broadcasting/auth") + async def broadcasting_auth(request: Request) -> JSONResponse: + """Laravel Echo / Pusher auth handshake endpoint. + + Reads ``channel_name`` and ``socket_id`` from the request body + (form-encoded, as sent by Laravel Echo), resolves the + authenticated user from ``request.state.user``, and delegates + to the :class:`~fastapi_startkit.reverb.registry.ChannelAuthRegistry`. + + Returns: + ``200`` with a signed auth string when authorized. + ``403`` when the registry denies the subscription. + """ + # Support both form-encoded and JSON bodies + content_type = request.headers.get("content-type", "") + if "application/json" in content_type: + body = await request.json() + channel_name = body.get("channel_name", "") + socket_id = body.get("socket_id", "") + else: + form = await request.form() + channel_name = form.get("channel_name", "") + socket_id = form.get("socket_id", "") + + # Authenticated user is expected on request.state by auth middleware + user = getattr(request.state, "user", None) + + authorized = await registry.authorize(str(channel_name), user) + + if authorized: + auth_token = f"{app_key}:{socket_id}" + return JSONResponse( + {"auth": auth_token, "channel_data": "{}"}, + status_code=200, + ) + + return JSONResponse({"message": "Forbidden"}, status_code=403) diff --git a/fastapi_startkit/src/fastapi_startkit/reverb/registry.py b/fastapi_startkit/src/fastapi_startkit/reverb/registry.py new file mode 100644 index 00000000..a45fbb0a --- /dev/null +++ b/fastapi_startkit/src/fastapi_startkit/reverb/registry.py @@ -0,0 +1,143 @@ +"""Channel authorization registry. + +The registry maps channel *patterns* (e.g. ``orders.{order_id}``) to async +(or sync) authorization callbacks. At subscription time the framework: + +1. Strips the ``private-``/``presence-`` prefix from the inbound channel name. +2. Iterates registered patterns and tries to match the stripped name against + each compiled regex. +3. On a match, extracts wildcard values, casts them to the callback's + declared type hints, then calls the callback with the authenticated user + and the extracted wildcards. +4. Returns ``True`` (authorized) or ``False`` (denied). + +Default behaviour when *no* pattern matches: + +- ``private-*`` / ``presence-*`` channels → **denied**. +- Public channels → **allowed**. +""" + +from __future__ import annotations + +import inspect +import re +from typing import Any, Callable, get_type_hints + + +class ChannelAuthRegistry: + """Registry for ``@Broadcast.channel`` authorization callbacks.""" + + def __init__(self) -> None: + # Each entry: (raw_pattern, compiled_regex, callback) + self._callbacks: list[tuple[str, re.Pattern[str], Callable[..., Any]]] = [] + + # ------------------------------------------------------------------ + # Registration + # ------------------------------------------------------------------ + + def register(self, pattern: str, callback: Callable[..., Any]) -> None: + """Register *callback* for *pattern*. + + Args: + pattern: Channel pattern with ``{name}`` placeholders, e.g. + ``"orders.{order_id}"`` or ``"chat.{room}.{topic}"``. + callback: Sync or async callable. Its first positional parameter + receives the authenticated user; subsequent keyword + arguments receive the wildcard values cast to their + declared type hints. + """ + compiled = self._compile_pattern(pattern) + self._callbacks.append((pattern, compiled, callback)) + + # ------------------------------------------------------------------ + # Authorization + # ------------------------------------------------------------------ + + async def authorize(self, channel_name: str, user: Any) -> bool: + """Authorize *user* for *channel_name*. + + Args: + channel_name: Full channel name as sent by the client, e.g. + ``"private-orders.42"`` or ``"orders.42"``. + user: Authenticated user object (may be ``None`` for + unauthenticated requests). + + Returns: + ``True`` if authorized, ``False`` otherwise. + """ + raw_name = self._strip_prefix(channel_name) + + for _pattern, compiled, callback in self._callbacks: + # Try stripped name first, fall back to full name + match = compiled.match(raw_name) or compiled.match(channel_name) + if match is None: + continue + + wildcards = match.groupdict() + kwargs = self._cast_wildcards(callback, wildcards) + + if inspect.iscoroutinefunction(callback): + result = await callback(user, **kwargs) + else: + result = callback(user, **kwargs) + + return bool(result) + + # Default policy + if channel_name.startswith(("private-", "presence-")): + return False + return True + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + + @staticmethod + def _strip_prefix(channel_name: str) -> str: + """Remove ``private-`` or ``presence-`` prefix for pattern matching.""" + for prefix in ("private-", "presence-"): + if channel_name.startswith(prefix): + return channel_name[len(prefix):] + return channel_name + + @staticmethod + def _compile_pattern(pattern: str) -> re.Pattern[str]: + """Convert ``orders.{order_id}`` to a named-group regex. + + Dots outside wildcards are treated as literal dots (not regex ``.``). + Each ``{name}`` wildcard matches one or more characters that are not + a literal dot, giving fine-grained segment-level matching. + """ + # Split on {name} tokens preserving the delimiters + parts = re.split(r"(\{[^}]+\})", pattern) + regex_parts: list[str] = [] + for part in parts: + if part.startswith("{") and part.endswith("}"): + wildcard_name = part[1:-1] + regex_parts.append(f"(?P<{wildcard_name}>[^.]+)") + else: + regex_parts.append(re.escape(part)) + return re.compile(f'^{"".join(regex_parts)}$') + + @staticmethod + def _cast_wildcards( + callback: Callable[..., Any], + wildcards: dict[str, str], + ) -> dict[str, Any]: + """Cast wildcard strings to the types declared in *callback*'s hints.""" + try: + hints = get_type_hints(callback) + except Exception: + hints = {} + + casted: dict[str, Any] = {} + for key, raw_val in wildcards.items(): + if key in hints: + target_type = hints[key] + try: + casted[key] = target_type(raw_val) + except (ValueError, TypeError): + casted[key] = raw_val + else: + casted[key] = raw_val + return casted