Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
from .auth import ChannelAuthRegistry
from .channels import Channel, PrivateChannel, PresenceChannel
from .event import BroadcastEvent, ShouldBroadcast
from .helpers import broadcast, channel
from .provider import ReverbProvider
from .helpers import broadcast

__all__ = [
"Channel",
"PrivateChannel",
"ChannelAuthRegistry",
"PresenceChannel",
"PrivateChannel",
"BroadcastEvent",
"ShouldBroadcast",
"ReverbProvider",
"ShouldBroadcast",
"broadcast",
"channel",
]
140 changes: 140 additions & 0 deletions fastapi_startkit/src/fastapi_startkit/broadcasting/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
"""Channel authorization registry for the Reverb broadcasting module.

Usage::

from fastapi_startkit.broadcasting import channel

@channel("orders.{order_id}")
async def authorize_orders_channel(user, order_id: int):
return user.id == order_id # True = authorized, False = denied

The registry is held on the ``BroadcastManager`` instance that lives in the
container. ``ReverbProvider`` mounts the ``/broadcasting/auth`` endpoint
which calls :meth:`ChannelAuthRegistry.authorize` to verify subscriptions.
"""

from __future__ import annotations

import inspect
import re
from collections.abc import Callable
from typing import Any


def _pattern_to_regex(pattern: str) -> re.Pattern:
"""Convert a ``{wildcard}``-style pattern to a compiled regex.

Example: ``"orders.{order_id}"`` -> ``^orders\\.(?P<order_id>[^.]+)$``
"""
# Escape everything except our placeholder tokens
parts = re.split(r"(\{[^}]+\})", pattern)
regex_parts: list[str] = []
for part in parts:
if part.startswith("{") and part.endswith("}"):
name = part[1:-1]
regex_parts.append(f"(?P<{name}>[^.]+)")
else:
regex_parts.append(re.escape(part))
return re.compile("^" + "".join(regex_parts) + "$")


class ChannelAuthRegistry:
"""Stores and evaluates channel authorization callbacks."""

def __init__(self) -> None:
# List of (compiled_pattern, original_pattern, callback)
self._rules: list[tuple[re.Pattern, str, Callable]] = []

# ------------------------------------------------------------------
# Registration
# ------------------------------------------------------------------

def register(self, pattern: str, callback: Callable) -> None:
"""Register *callback* as the authorizer for channels matching *pattern*."""
self._rules.append((_pattern_to_regex(pattern), pattern, callback))

def channel(self, pattern: str) -> Callable:
"""Decorator factory — registers the decorated function for *pattern*.

Example::

@registry.channel("orders.{order_id}")
async def authorize(user, order_id: int):
return user.id == order_id
"""

def decorator(func: Callable) -> Callable:
self.register(pattern, func)
return func

return decorator

# ------------------------------------------------------------------
# Lookup
# ------------------------------------------------------------------

def _find(self, channel_name: str) -> tuple[Callable, dict[str, str]] | None:
"""Return ``(callback, wildcard_kwargs)`` for *channel_name*, or ``None``."""
for compiled, _pattern, callback in self._rules:
m = compiled.match(channel_name)
if m:
return callback, m.groupdict()
return None

# ------------------------------------------------------------------
# Authorization
# ------------------------------------------------------------------

async def authorize(self, channel_name: str, user: Any) -> bool:
"""Evaluate the authorization callback for *channel_name*.

* Public channels (no ``private-`` / ``presence-`` prefix): always
allowed, no callback needed.
* Private / presence channels without a registered callback: **denied**
by default (fail-safe).
* Private / presence channels with a registered callback: the return
value of the callback determines access.

Wildcard values extracted from the pattern are cast to the type hints
of the matching parameters before the callback is called.
"""

is_private = channel_name.startswith("private-") or channel_name.startswith("presence-")

result = self._find(channel_name)

# Public channel with no rule → allow
if result is None and not is_private:
return True

# Private/presence channel with no rule → deny
if result is None:
return False

callback, raw_kwargs = result

# Cast wildcard values to the declared type hints of the callback
hints = {}
try:
sig = inspect.signature(callback)
hints = {
k: p.annotation
for k, p in sig.parameters.items()
if p.annotation is not inspect.Parameter.empty and k != "user"
}
except (ValueError, TypeError):
pass

cast_kwargs: dict[str, Any] = {}
for k, v in raw_kwargs.items():
try:
cast_kwargs[k] = hints[k](v) if k in hints else v
except (ValueError, TypeError):
cast_kwargs[k] = v

# Call the callback (sync or async)
result_value = callback(user, **cast_kwargs)
if inspect.isawaitable(result_value):
result_value = await result_value

return bool(result_value)
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,14 @@ def __init__(self, name: str):
def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.name!r})"

def __eq__(self, other: object) -> bool:
if not isinstance(other, Channel):
return NotImplemented
return self.name == other.name

def __hash__(self) -> int:
return hash(self.name)


class PrivateChannel(Channel):
def __init__(self, name: str):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,7 @@ async def broadcast(self, event) -> None:
logger.info(
f"[Broadcast] channel={channel.name} event={event.broadcast_as()} data={event.broadcast_with()}"
)

async def broadcast_raw(self, channel: str, event_name: str, payload: dict) -> None:
"""Log a raw channel/event/payload triple (used by :meth:`BroadcastManager.emit`)."""
logger.info(f"[Broadcast] channel={channel} event={event_name} data={payload}")
60 changes: 59 additions & 1 deletion fastapi_startkit/src/fastapi_startkit/broadcasting/event.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,12 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING

from fastapi_startkit.application import app

if TYPE_CHECKING:
pass


class ShouldBroadcast(ABC):
Expand All @@ -7,8 +15,58 @@ def broadcast_on(self) -> list: ...


class BroadcastEvent(ShouldBroadcast):
"""Base class for all broadcastable events.

Subclasses must implement ``broadcast_on()`` returning a list of Channel
objects. At dispatch time the event is pushed to every channel in that
list using the configured broadcast driver.

Attributes:
payload: Arbitrary data dict sent to subscribers. Defaults to an
empty dict; subclasses can populate it in ``__init__`` or by
overriding ``broadcast_with()``.
name: Event name seen by subscribers. Defaults to the class name at
emit time. Override to use a custom string.
"""

payload: dict = {}
name: str | None = None

# ------------------------------------------------------------------
# Driver-layer helpers (used by ReverbDriver / LogDriver)
# ------------------------------------------------------------------

def broadcast_as(self) -> str:
return self.__class__.__name__
"""Return the event name used on the wire.

Uses ``self.name`` when set; otherwise falls back to the class name.
"""
return self.name if self.name is not None else self.__class__.__name__

def broadcast_with(self) -> dict:
"""Return the payload dict.

Returns ``self.payload`` when it is non-empty; otherwise serialises
all instance attributes (backward-compatible behaviour).
"""
if self.payload:
return dict(self.payload)
return {k: v for k, v in self.__dict__.items()}

# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------

async def emit(self) -> None:
"""Dispatch this event via the BroadcastManager.

Resolves the ``"broadcasting"`` service directly from the Application
container — no facade in the call chain.

Equivalent to::

from fastapi_startkit.application import app
await app().make("broadcasting").dispatch(self)
"""
manager = app().make("broadcasting")
await manager.dispatch(self)
32 changes: 32 additions & 0 deletions fastapi_startkit/src/fastapi_startkit/broadcasting/helpers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,39 @@
from collections.abc import Callable

from ..application import app


async def broadcast(event) -> None:
"""Broadcast an event using the default driver."""
manager = app().make("broadcasting")
await manager.event(event)


def channel(pattern: str) -> Callable:
"""Register a channel authorization callback.

A decorator factory that registers the decorated function with the
``ChannelAuthRegistry`` on the booted ``BroadcastManager``.

Usage::

from fastapi_startkit.broadcasting import channel

@channel("orders.{order_id}")
async def authorize_orders(user, order_id: int) -> bool:
return user is not None and user.id == order_id

@channel("private-notifications")
async def authorize_notifications(user) -> bool:
return user is not None

Pattern supports ``{wildcard}`` placeholders; wildcard values are cast to
the type-hints of the matching callback parameters before calling.
"""

def decorator(func: Callable) -> Callable:
manager = app().make("broadcasting")
manager.channel_registry.register(pattern, func)
return func

return decorator
81 changes: 78 additions & 3 deletions fastapi_startkit/src/fastapi_startkit/broadcasting/manager.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,34 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any

from .auth import ChannelAuthRegistry

if TYPE_CHECKING:
from .event import BroadcastEvent


class BroadcastManager:
def __init__(self, config, server=None):
"""Central broadcast manager.
Holds the active driver, the Reverb server instance (if any), and the
:class:`~.auth.ChannelAuthRegistry` used by the ``/broadcasting/auth``
endpoint.
This is the service that the :class:`~fastapi_startkit.facades.Broadcast`
facade proxies to (container key: ``"broadcasting"``).
"""

def __init__(self, config: dict, server: Any = None) -> None:
self._config = config
self._server = server
self.channel_registry = ChannelAuthRegistry()

# ------------------------------------------------------------------
# Driver resolution
# ------------------------------------------------------------------

def driver(self, name=None):
def driver(self, name: str | None = None):
name = name or self._config.get("default", "log")
if name == "log":
from .drivers.log_driver import LogDriver
Expand All @@ -15,5 +40,55 @@ def driver(self, name=None):
return ReverbDriver(self._server)
raise ValueError(f"Unknown broadcast driver: {name}")

async def event(self, broadcast_event) -> None:
# ------------------------------------------------------------------
# Dispatch helpers
# ------------------------------------------------------------------

async def event(self, broadcast_event: "BroadcastEvent") -> None:
"""Broadcast *broadcast_event* using the default driver.
.. deprecated::
Prefer :meth:`dispatch` for new code.
"""
await self.driver().broadcast(broadcast_event)

async def dispatch(self, broadcast_event: "BroadcastEvent") -> None:
"""Primary dispatch path.
Sends *broadcast_event* to all channels returned by its
``broadcast_on()`` method via the configured driver.
"""
await self.driver().broadcast(broadcast_event)

async def emit(self, channel: str, event_name: str, payload: dict) -> None:
"""Escape-hatch for direct / dynamic broadcasts.
Sends *payload* as *event_name* to *channel* without needing a
``BroadcastEvent`` subclass.
"""
driver_name = self._config.get("default", "log")
if driver_name == "reverb" and self._server is not None:
await self._server.broadcast_to_channel(channel, event_name, payload)
else:
# Log driver fallback
from .drivers.log_driver import LogDriver

driver = LogDriver()
await driver.broadcast_raw(channel, event_name, payload)

# ------------------------------------------------------------------
# Channel authorization registry proxy
# ------------------------------------------------------------------

def channel(self, pattern: str):
"""Decorator factory — register a channel authorization callback.
Example::
Broadcast.channel("orders.{order_id}")
async def authorize(user, order_id: int):
return user.id == order_id
Delegates to :attr:`channel_registry`.
"""
return self.channel_registry.channel(pattern)
Loading
Loading