diff --git a/ably/__init__.py b/ably/__init__.py index 5c60ef3b..d5ee1736 100644 --- a/ably/__init__.py +++ b/ably/__init__.py @@ -4,7 +4,10 @@ from ably.rest.auth import Auth from ably.rest.push import Push from ably.rest.rest import AblyRest +from ably.types.annotation import Annotation, AnnotationAction from ably.types.capability import Capability +from ably.types.channelmode import ChannelMode +from ably.types.channeloptions import ChannelOptions from ably.types.channelsubscription import PushChannelSubscription from ably.types.device import DeviceDetails from ably.types.message import MessageAction, MessageVersion diff --git a/ably/realtime/annotations.py b/ably/realtime/annotations.py new file mode 100644 index 00000000..fbbbb755 --- /dev/null +++ b/ably/realtime/annotations.py @@ -0,0 +1,263 @@ +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +from ably.rest.annotations import RestAnnotations, construct_validate_annotation +from ably.transport.websockettransport import ProtocolMessageAction +from ably.types.annotation import Annotation, AnnotationAction +from ably.types.channelmode import ChannelMode +from ably.types.channelstate import ChannelState +from ably.util.eventemitter import EventEmitter +from ably.util.helper import is_callable_or_coroutine + +if TYPE_CHECKING: + from ably.realtime.channel import RealtimeChannel + from ably.realtime.connectionmanager import ConnectionManager + +log = logging.getLogger(__name__) + + +class RealtimeAnnotations: + """ + Provides realtime methods for managing annotations on messages, + including publishing annotations and subscribing to annotation events. + """ + + __connection_manager: ConnectionManager + __channel: RealtimeChannel + + def __init__(self, channel: RealtimeChannel, connection_manager: ConnectionManager): + """ + Initialize RealtimeAnnotations. + + Args: + channel: The Realtime Channel this annotations instance belongs to + """ + self.__channel = channel + self.__connection_manager = connection_manager + self.__subscriptions = EventEmitter() + self.__rest_annotations = RestAnnotations(channel) + + async def __send_annotation(self, annotation: Annotation, params: dict | None = None): + """ + Internal method to send an annotation via the realtime connection. + + Args: + annotation: Validated Annotation object with action and message_serial set + params: Optional dict of query parameters + """ + # Check if channel and connection are in publishable state + self.__channel._throw_if_unpublishable_state() + + log.info( + f'RealtimeAnnotations: sending annotation, channelName = {self.__channel.name}, ' + f'messageSerial = {annotation.message_serial}, ' + f'type = {annotation.type}, action = {annotation.action}' + ) + + # Convert to wire format (array of annotations) + wire_annotation = annotation.as_dict(binary=self.__channel.ably.options.use_binary_protocol) + + # Build protocol message + protocol_message = { + "action": ProtocolMessageAction.ANNOTATION, + "channel": self.__channel.name, + "annotations": [wire_annotation], + } + + if params: + # Stringify boolean params + stringified_params = {k: str(v).lower() if isinstance(v, bool) else v for k, v in params.items()} + protocol_message["params"] = stringified_params + + # Send via WebSocket + await self.__connection_manager.send_protocol_message(protocol_message) + + async def publish(self, msg_or_serial, annotation: Annotation, params: dict | None = None): + """ + Publish an annotation on a message via the realtime connection. + + Args: + msg_or_serial: Either a message serial (string) or a Message object + annotation: Annotation object + params: Optional dict of query parameters + + Returns: + None + + Raises: + AblyException: If the request fails, inputs are invalid, or channel is in unpublishable state + """ + annotation = construct_validate_annotation(msg_or_serial, annotation) + + # RSAN1c1/RTAN1a: Explicitly set action to ANNOTATION_CREATE + annotation = annotation._copy_with(action=AnnotationAction.ANNOTATION_CREATE) + + await self.__send_annotation(annotation, params) + + async def delete( + self, + msg_or_serial, + annotation: Annotation, + params: dict | None = None, + ): + """ + Delete an annotation on a message. + + Args: + msg_or_serial: Either a message serial (string) or a Message object + annotation: Annotation containing annotation properties + params: Optional dict of query parameters + + Returns: + None + + Raises: + AblyException: If the request fails or inputs are invalid + """ + annotation = construct_validate_annotation(msg_or_serial, annotation) + + # RSAN2a/RTAN2a: Explicitly set action to ANNOTATION_DELETE + annotation = annotation._copy_with(action=AnnotationAction.ANNOTATION_DELETE) + + await self.__send_annotation(annotation, params) + + async def subscribe(self, *args): + """ + Subscribe to annotation events on this channel. + + Parameters + ---------- + *args: type_or_types, listener + Subscribe type(s) and listener + + arg1(type_or_types): str or list[str], optional + Subscribe to annotations of the given type or types (RTAN4c) + + arg2(listener): callable + Subscribe to all annotations on the channel + + When no type is provided, arg1 is used as the listener. + + Raises + ------ + ValueError + If no valid subscribe arguments are passed + """ + # Parse arguments similar to channel.subscribe + if len(args) == 0: + raise ValueError("annotations.subscribe called without arguments") + + annotation_types = None + + # RTAN4c: Support string or list of strings as first argument + if len(args) >= 2 and isinstance(args[0], (str, list)): + if isinstance(args[0], list): + annotation_types = args[0] + else: + annotation_types = [args[0]] + if not args[1]: + raise ValueError("annotations.subscribe called without listener") + if not is_callable_or_coroutine(args[1]): + raise ValueError("subscribe listener must be function or coroutine function") + listener = args[1] + elif is_callable_or_coroutine(args[0]): + listener = args[0] + else: + raise ValueError('invalid subscribe arguments') + + # RTAN4d: Implicitly attach channel on subscribe + await self.__channel.attach() + + # RTAN4e: Check if ANNOTATION_SUBSCRIBE mode is enabled (log warning per spec), + # only when server explicitly sent modes (non-empty list) + if self.__channel.state == ChannelState.ATTACHED and self.__channel.modes: + if ChannelMode.ANNOTATION_SUBSCRIBE not in self.__channel.modes: + log.warning( + "You are trying to add an annotation listener, but the " + "ANNOTATION_SUBSCRIBE channel mode was not included in the ATTACHED flags. " + "This subscription may not receive annotations. Ensure you request the " + "annotation_subscribe channel mode in ChannelOptions." + ) + + # Register subscription after successful attach + if annotation_types is not None: + for t in annotation_types: + self.__subscriptions.on(t, listener) + else: + self.__subscriptions.on(listener) + + def unsubscribe(self, *args): + """ + Unsubscribe from annotation events on this channel. + + Parameters + ---------- + *args: type_or_types, listener + Unsubscribe type(s) and listener + + arg1(type_or_types): str or list[str], optional + Unsubscribe from annotations of the given type or types + + arg2(listener): callable + Unsubscribe from all annotations on the channel + + When no type is provided, arg1 is used as the listener. + When no arguments are provided, unsubscribes all annotation listeners (RTAN5). + + Raises + ------ + ValueError + If invalid unsubscribe arguments are passed + """ + # RTAN5: Support no arguments to unsubscribe all annotation listeners + if len(args) == 0: + self.__subscriptions.off() + elif len(args) >= 2 and isinstance(args[0], (str, list)): + # RTAN5a: Support string or list of strings for type(s) + if isinstance(args[0], list): + annotation_types = args[0] + else: + annotation_types = [args[0]] + listener = args[1] + for t in annotation_types: + self.__subscriptions.off(t, listener) + elif is_callable_or_coroutine(args[0]): + listener = args[0] + self.__subscriptions.off(listener) + else: + raise ValueError('invalid unsubscribe arguments') + + def _process_incoming(self, incoming_annotations): + """ + Process incoming annotations from the server. + + This is called internally when ANNOTATION protocol messages are received. + + Args: + incoming_annotations: List of Annotation objects received from the server + """ + for annotation in incoming_annotations: + # Emit to type-specific listeners and catch-all listeners + annotation_type = annotation.type or '' + self.__subscriptions._emit(annotation_type, annotation) + + async def get(self, msg_or_serial, params: dict | None = None): + """ + Retrieve annotations for a message with pagination support. + + This delegates to the REST implementation. + + Args: + msg_or_serial: Either a message serial (string) or a Message object + params: Optional dict of query parameters (limit, start, end, direction) + + Returns: + PaginatedResult: A paginated result containing Annotation objects + + Raises: + AblyException: If the request fails or serial is invalid + """ + # Delegate to REST implementation + return await self.__rest_annotations.get(msg_or_serial, params) diff --git a/ably/realtime/channel.py b/ably/realtime/channel.py index e0fd6251..768eeb7d 100644 --- a/ably/realtime/channel.py +++ b/ably/realtime/channel.py @@ -4,10 +4,14 @@ import logging from typing import TYPE_CHECKING +from ably.realtime.annotations import RealtimeAnnotations from ably.realtime.connection import ConnectionState +from ably.realtime.presence import RealtimePresence from ably.rest.channel import Channel from ably.rest.channel import Channels as RestChannels from ably.transport.websockettransport import ProtocolMessageAction +from ably.types.annotation import Annotation +from ably.types.channelmode import ChannelMode, decode_channel_mode, encode_channel_mode from ably.types.channeloptions import ChannelOptions from ably.types.channelstate import ChannelState, ChannelStateChange from ably.types.flags import Flag, has_flag @@ -64,6 +68,7 @@ def __init__(self, realtime: AblyRealtime, name: str, channel_options: ChannelOp self.__error_reason: AblyException | None = None self.__channel_options = channel_options or ChannelOptions() self.__params: dict[str, str] | None = None + self.__modes: list[ChannelMode] = [] # Channel mode flags from ATTACHED message # Delta-specific fields for RTL19/RTL20 compliance vcdiff_decoder = self.__realtime.options.vcdiff_decoder if self.__realtime.options.vcdiff_decoder else None @@ -74,12 +79,15 @@ def __init__(self, realtime: AblyRealtime, name: str, channel_options: ChannelOp # will be disrupted if the user called .off() to remove all listeners self.__internal_state_emitter = EventEmitter() + # Pass channel options as dictionary to parent Channel class + Channel.__init__(self, realtime, name, self.__channel_options.to_dict()) + # Initialize presence for this channel - from ably.realtime.presence import RealtimePresence + self.__presence = RealtimePresence(self) - # Pass channel options as dictionary to parent Channel class - Channel.__init__(self, realtime, name, self.__channel_options.to_dict()) + # Initialize realtime annotations for this channel (override REST annotations) + self._Channel__annotations = RealtimeAnnotations(self, realtime.connection.connection_manager) async def set_options(self, channel_options: ChannelOptions) -> None: """Set channel options""" @@ -149,8 +157,10 @@ def _attach_impl(self): "channel": self.name, } - if self.__attach_resume: - attach_msg["flags"] = Flag.ATTACH_RESUME + flags = self._encode_flags() + + if flags: + attach_msg["flags"] = flags if self.__channel_serial: attach_msg["channelSerial"] = self.__channel_serial @@ -491,8 +501,8 @@ async def _send_update( if not message.serial: raise AblyException( "Message serial is required for update/delete/append operations", - 400, - 40003 + status_code=400, + code=40003, ) # Check connection and channel state @@ -530,7 +540,7 @@ async def _send_update( f'channel = {self.name}, state = {self.state}, serial = {message.serial}' ) - stringified_params = {k: str(v).lower() if type(v) is bool else v for k, v in params.items()} \ + stringified_params = {k: str(v).lower() if isinstance(v, bool) else v for k, v in params.items()} \ if params else None # Send protocol message @@ -702,6 +712,8 @@ def _on_message(self, proto_msg: dict) -> None: resumed = has_flag(flags, Flag.RESUMED) # RTP1: Check for HAS_PRESENCE flag has_presence = has_flag(flags, Flag.HAS_PRESENCE) + # Store channel attach flags + self.__modes = decode_channel_mode(flags) # RTL12 if self.state == ChannelState.ATTACHED: @@ -744,6 +756,19 @@ def _on_message(self, proto_msg: dict) -> None: decoded_presence = PresenceMessage.from_encoded_array(presence_messages, cipher=self.cipher) sync_channel_serial = proto_msg.get('channelSerial') self.__presence.set_presence(decoded_presence, is_sync=True, sync_channel_serial=sync_channel_serial) + elif action == ProtocolMessageAction.ANNOTATION: + # Handle ANNOTATION messages + # RTAN4b: Populate annotation fields from protocol message + Annotation.update_inner_annotation_fields(proto_msg) + annotation_data = proto_msg.get('annotations', []) + try: + annotations = Annotation.from_encoded_array(annotation_data, cipher=self.cipher) + # Process annotations through the annotations handler + self.annotations._process_incoming(annotations) + # RTL15b: Update channel serial for ANNOTATION messages + self.__channel_serial = channel_serial + except Exception as e: + log.error(f"Annotation processing error {e}. Skip annotations {annotation_data}") elif action == ProtocolMessageAction.ERROR: error = AblyException.from_dict(proto_msg.get('error')) self._notify_state(ChannelState.FAILED, reason=error) @@ -890,6 +915,15 @@ def presence(self): """Get the RealtimePresence object for this channel""" return self.__presence + @property + def annotations(self) -> RealtimeAnnotations: + return self._Channel__annotations + + @property + def modes(self): + """Get the list of channel modes""" + return self.__modes + def _start_decode_failure_recovery(self, error: AblyException) -> None: """Start RTL18 decode failure recovery procedure""" @@ -908,6 +942,20 @@ def _start_decode_failure_recovery(self, error: AblyException) -> None: self._notify_state(ChannelState.ATTACHING, reason=error) self._check_pending_state() + def _encode_flags(self) -> int | None: + if not self.__channel_options.modes and not self.__attach_resume: + return None + + flags = 0 + + if self.__attach_resume: + flags |= Flag.ATTACH_RESUME + + if self.__channel_options.modes: + flags |= encode_channel_mode(self.__channel_options.modes) + + return flags + class Channels(RestChannels): """Creates and destroys RealtimeChannel objects. diff --git a/ably/rest/annotations.py b/ably/rest/annotations.py new file mode 100644 index 00000000..fc2b29d5 --- /dev/null +++ b/ably/rest/annotations.py @@ -0,0 +1,242 @@ +from __future__ import annotations + +import base64 +import json +import logging +import os +from urllib import parse + +import msgpack + +from ably.http.paginatedresult import PaginatedResult, format_params +from ably.types.annotation import ( + Annotation, + AnnotationAction, + make_annotation_response_handler, +) +from ably.types.message import Message +from ably.types.options import Options +from ably.util.exceptions import AblyException + +log = logging.getLogger(__name__) + + +def serial_from_msg_or_serial(msg_or_serial): + """ + Extract the message serial from either a string serial or a Message object. + + Args: + msg_or_serial: Either a string serial or a Message object with a serial property + + Returns: + str: The message serial + + Raises: + AblyException: If the input is invalid or serial is missing + """ + if isinstance(msg_or_serial, str): + message_serial = msg_or_serial + elif isinstance(msg_or_serial, Message): + message_serial = msg_or_serial.serial + else: + message_serial = None + + if not message_serial or not isinstance(message_serial, str): + raise AblyException( + message='First argument of annotations.publish() must be either a Message ' + 'or a message serial (string)', + status_code=400, + code=40003, + ) + + return message_serial + + +def construct_validate_annotation(msg_or_serial, annotation: Annotation) -> Annotation: + """ + Construct and validate an Annotation from input values. + + Args: + msg_or_serial: Either a string serial or a Message object + annotation: Annotation object + + Returns: + Annotation: The constructed annotation + + Raises: + AblyException: If the inputs are invalid + """ + message_serial = serial_from_msg_or_serial(msg_or_serial) + + if not annotation or not isinstance(annotation, Annotation): + raise AblyException( + message='Second argument of annotations.publish() must be an Annotation ' + '(the intended annotation to publish)', + status_code=400, + code=40003, + ) + + # RSAN1a3: Validate that annotation type is specified + if not annotation.type: + raise AblyException( + message='Annotation type must be specified', + status_code=400, + code=40000, + ) + + return annotation._copy_with( + message_serial=message_serial, + ) + + +class RestAnnotations: + """ + Provides REST API methods for managing annotations on messages. + """ + + __client_options: Options + + def __init__(self, channel): + """ + Initialize RestAnnotations. + + Args: + channel: The REST Channel this annotations instance belongs to + """ + self.__channel = channel + self.__client_options = channel.ably.options + + def __base_path_for_serial(self, serial): + """ + Build the base API path for a message serial's annotations. + + Args: + serial: The message serial + + Returns: + str: The API path + """ + channel_path = '/channels/{}/'.format(parse.quote_plus(self.__channel.name, safe=':')) + return channel_path + 'messages/' + parse.quote_plus(serial, safe=':') + '/annotations' + + async def __send_annotation(self, annotation: Annotation, params: dict | None = None): + """ + Internal method to send an annotation to the API. + + Args: + annotation: Validated Annotation object with action and message_serial set + params: Optional dict of query parameters + """ + # RSAN1c4: Generate random ID if not provided (for idempotent publishing) + # Spec: base64-encode at least 9 random bytes, append ':0' + if not annotation.id and self.__client_options.idempotent_rest_publishing: + random_id = base64.b64encode(os.urandom(9)).decode('ascii') + ':0' + annotation = annotation._copy_with(id=random_id) + + # Convert to wire format + request_body = annotation.as_dict(binary=self.__channel.ably.options.use_binary_protocol) + + # Wrap in array as API expects array of annotations + request_body = [request_body] + + # Encode based on protocol + if not self.__channel.ably.options.use_binary_protocol: + request_body = json.dumps(request_body, separators=(',', ':')) + else: + request_body = msgpack.packb(request_body, use_bin_type=True) + + # Build path + path = self.__base_path_for_serial(annotation.message_serial) + if params: + params = {k: str(v).lower() if isinstance(v, bool) else v for k, v in params.items()} + path += '?' + parse.urlencode(params) + + # Send request + await self.__channel.ably.http.post(path, body=request_body) + + async def publish( + self, + msg_or_serial, + annotation: Annotation, + params: dict | None = None, + ): + """ + Publish an annotation on a message. + + Args: + msg_or_serial: Either a message serial (string) or a Message object + annotation: Annotation object + params: Optional dict of query parameters + + Returns: + None + + Raises: + AblyException: If the request fails or inputs are invalid + """ + annotation = construct_validate_annotation(msg_or_serial, annotation) + + # RSAN1c1: Explicitly set action to ANNOTATION_CREATE + annotation = annotation._copy_with(action=AnnotationAction.ANNOTATION_CREATE) + + await self.__send_annotation(annotation, params) + + async def delete( + self, + msg_or_serial, + annotation: Annotation, + params: dict | None = None, + ): + """ + Delete an annotation on a message. + + This is a convenience method that sets the action to 'annotation.delete' + and calls publish(). + + Args: + msg_or_serial: Either a message serial (string) or a Message object + annotation: Annotation object + params: Optional dict of query parameters + + Returns: + None + + Raises: + AblyException: If the request fails or inputs are invalid + """ + annotation = construct_validate_annotation(msg_or_serial, annotation) + + # RSAN2a: Explicitly set action to ANNOTATION_DELETE + annotation = annotation._copy_with(action=AnnotationAction.ANNOTATION_DELETE) + + return await self.__send_annotation(annotation, params) + + async def get(self, msg_or_serial, params: dict | None = None): + """ + Retrieve annotations for a message with pagination support. + + Args: + msg_or_serial: Either a message serial (string) or a Message object + params: Optional dict of query parameters (limit, start, end, direction) + + Returns: + PaginatedResult: A paginated result containing Annotation objects + + Raises: + AblyException: If the request fails or serial is invalid + """ + message_serial = serial_from_msg_or_serial(msg_or_serial) + + # Build path + params_str = format_params({}, **params) if params else '' + path = self.__base_path_for_serial(message_serial) + params_str + + # Create annotation response handler + annotation_handler = make_annotation_response_handler(cipher=None) + + # Return paginated result + return await PaginatedResult.paginated_query( + self.__channel.ably.http, + url=path, + response_processor=annotation_handler + ) diff --git a/ably/rest/auth.py b/ably/rest/auth.py index 2aaa4b12..d2057533 100644 --- a/ably/rest/auth.py +++ b/ably/rest/auth.py @@ -89,8 +89,8 @@ def __init__(self, ably: AblyRest | AblyRealtime, options: Options): async def get_auth_transport_param(self): auth_credentials = {} - if self.auth_options.client_id: - auth_credentials["client_id"] = self.auth_options.client_id + if self.auth_options.client_id and self.auth_options.client_id != '*': + auth_credentials["clientId"] = self.auth_options.client_id if self.__auth_mechanism == Auth.Method.BASIC: key_name = self.__auth_options.key_name key_secret = self.__auth_options.key_secret diff --git a/ably/rest/channel.py b/ably/rest/channel.py index 2c1c0246..f6b118b7 100644 --- a/ably/rest/channel.py +++ b/ably/rest/channel.py @@ -9,6 +9,7 @@ import msgpack from ably.http.paginatedresult import PaginatedResult, format_params +from ably.rest.annotations import RestAnnotations from ably.types.channeldetails import ChannelDetails from ably.types.message import ( Message, @@ -30,6 +31,8 @@ class Channel: + __annotations: RestAnnotations + def __init__(self, ably, name, options): self.__ably = ably self.__name = name @@ -37,6 +40,7 @@ def __init__(self, ably, name, options): self.__cipher = None self.options = options self.__presence = Presence(self) + self.__annotations = RestAnnotations(self) @catch_all async def history(self, direction=None, limit: int = None, start=None, end=None): @@ -108,7 +112,7 @@ async def publish_messages(self, messages, params=None, timeout=None): path = self.__base_path + 'messages' if params: - params = {k: str(v).lower() if type(v) is bool else v for k, v in params.items()} + params = {k: str(v).lower() if isinstance(v, bool) else v for k, v in params.items()} path += '?' + parse.urlencode(params) response = await self.ably.http.post(path, body=request_body, timeout=timeout) @@ -169,8 +173,8 @@ async def _send_update( if not message.serial: raise AblyException( "Message serial is required for update/delete/append operations", - 400, - 40003 + status_code=400, + code=40003, ) if not operation: @@ -207,7 +211,7 @@ async def _send_update( # Build path with params path = self.__base_path + 'messages/{}'.format(parse.quote_plus(message.serial, safe=':')) if params: - params = {k: str(v).lower() if type(v) is bool else v for k, v in params.items()} + params = {k: str(v).lower() if isinstance(v, bool) else v for k, v in params.items()} path += '?' + parse.urlencode(params) # Send request @@ -282,8 +286,8 @@ async def get_message(self, serial_or_message, timeout=None): raise AblyException( 'This message lacks a serial. Make sure you have enabled "Message annotations, ' 'updates, and deletes" in channel settings on your dashboard.', - 400, - 40003 + status_code=400, + code=40003, ) # Build the path @@ -321,8 +325,8 @@ async def get_message_versions(self, serial_or_message, params=None): raise AblyException( 'This message lacks a serial. Make sure you have enabled "Message annotations, ' 'updates, and deletes" in channel settings on your dashboard.', - 400, - 40003 + status_code=400, + code=40003, ) # Build the path @@ -363,6 +367,10 @@ def options(self): def presence(self): return self.__presence + @property + def annotations(self) -> RestAnnotations: + return self.__annotations + @options.setter def options(self, options): self.__options = options diff --git a/ably/transport/websockettransport.py b/ably/transport/websockettransport.py index 4f6f9fe0..be13d096 100644 --- a/ably/transport/websockettransport.py +++ b/ably/transport/websockettransport.py @@ -189,6 +189,7 @@ async def on_protocol_message(self, msg): ProtocolMessageAction.DETACHED, ProtocolMessageAction.MESSAGE, ProtocolMessageAction.PRESENCE, + ProtocolMessageAction.ANNOTATION, ProtocolMessageAction.SYNC ): self.connection_manager.on_channel_message(msg) diff --git a/ably/types/annotation.py b/ably/types/annotation.py new file mode 100644 index 00000000..c0926f58 --- /dev/null +++ b/ably/types/annotation.py @@ -0,0 +1,336 @@ +import logging +from enum import IntEnum + +from ably.types.mixins import EncodeDataMixin +from ably.util.encoding import encode_data +from ably.util.helper import to_text + +log = logging.getLogger(__name__) + + +# Sentinel value to distinguish between "not provided" and "explicitly None" +_UNSET = object() + + +class AnnotationAction(IntEnum): + """Annotation action types""" + ANNOTATION_CREATE = 0 + ANNOTATION_DELETE = 1 + + +class Annotation(EncodeDataMixin): + """ + Represents an annotation on a message, such as a reaction or other metadata. + + Annotations are not encrypted as they need to be parsed by the server for summarization. + """ + + def __init__(self, + action=None, + serial=None, + message_serial=None, + type=None, + name=None, + count=None, + data=None, + encoding='', + id=None, + client_id=None, + connection_id=None, + timestamp=None, + extras=None): + """ + Args: + action: The action type - either 'annotation.create' or 'annotation.delete' + serial: A unique identifier for the annotation + message_serial: The serial of the message this annotation is for + type: The type of annotation (e.g., 'reaction', 'like', etc.) + name: The name/value of the annotation (e.g., specific emoji) + count: Count associated with the annotation + data: Optional data payload for the annotation + encoding: Encoding format for the data + id: (TAN2a) A unique identifier for this annotation + client_id: The client ID that created this annotation + connection_id: The connection ID that created this annotation + timestamp: Timestamp of the annotation + extras: Additional metadata + """ + super().__init__(encoding) + + self.__serial = to_text(serial) if serial is not None else None + self.__message_serial = to_text(message_serial) if message_serial is not None else None + self.__type = to_text(type) if type is not None else None + self.__name = to_text(name) if name is not None else None + self.__action = action if action is not None else AnnotationAction.ANNOTATION_CREATE + self.__count = count + self.__data = data + self.__id = to_text(id) if id is not None else None + self.__client_id = to_text(client_id) if client_id is not None else None + self.__connection_id = to_text(connection_id) if connection_id is not None else None + self.__timestamp = timestamp + self.__extras = extras + self.__encoding = encoding + + def __eq__(self, other): + if isinstance(other, Annotation): + # TAN2i: serial is the unique identifier for the annotation + # If both have serials, use serial for comparison + if self.serial is not None and other.serial is not None: + return self.serial == other.serial + # Otherwise fall back to comparing multiple fields + return (self.message_serial == other.message_serial + and self.type == other.type + and self.name == other.name + and self.action == other.action + and self.client_id == other.client_id) + return NotImplemented + + def __ne__(self, other): + if isinstance(other, Annotation): + result = self.__eq__(other) + if result != NotImplemented: + return not result + return NotImplemented + + @property + def action(self): + return self.__action + + @property + def serial(self): + return self.__serial + + @property + def message_serial(self): + return self.__message_serial + + @property + def type(self): + return self.__type + + @property + def name(self): + return self.__name + + @property + def count(self): + return self.__count + + @property + def data(self): + return self.__data + + @property + def client_id(self): + return self.__client_id + + @property + def timestamp(self): + return self.__timestamp + + @property + def extras(self): + return self.__extras + + @property + def id(self): + return self.__id + + @property + def connection_id(self): + return self.__connection_id + + def as_dict(self, binary=False): + """ + Convert annotation to dictionary format for API communication. + + Note: Annotations are not encrypted as they need to be parsed by the server. + """ + request_body = { + 'action': int(self.action) if self.action is not None else None, + 'serial': self.serial, + 'messageSerial': self.message_serial, + 'type': self.type, # Annotation type (not data type) + 'name': self.name, + 'count': self.count, + 'id': self.id or None, + 'clientId': self.client_id or None, + 'connectionId': self.connection_id or None, + 'timestamp': self.timestamp or None, + 'extras': self.extras, + **encode_data(self.data, self._encoding_array, binary) + } + + # None values aren't included + request_body = {k: v for k, v in request_body.items() if v is not None} + + return request_body + + @staticmethod + def from_encoded(obj, cipher=None, context=None): + """ + Create an Annotation from an encoded object received from the API. + + Note: cipher parameter is accepted for consistency but annotations are not encrypted. + """ + action = obj.get('action') + serial = obj.get('serial') + message_serial = obj.get('messageSerial') + type_val = obj.get('type') + name = obj.get('name') + count = obj.get('count') + data = obj.get('data') + encoding = obj.get('encoding', '') + id = obj.get('id') + client_id = obj.get('clientId') + connection_id = obj.get('connectionId') + timestamp = obj.get('timestamp') + extras = obj.get('extras', None) + + # Decode data if present, passing data=None explicitly when absent + decoded_data = Annotation.decode(data, encoding, cipher, context) if data is not None else {'data': None} + + # Convert action from int to enum + if action is not None: + try: + action = AnnotationAction(action) + except ValueError: + # If it's not a valid action value, store as None + action = None + else: + action = None + + return Annotation( + action=action, + serial=serial, + message_serial=message_serial, + type=type_val, + name=name, + count=count, + id=id, + client_id=client_id, + connection_id=connection_id, + timestamp=timestamp, + extras=extras, + **decoded_data + ) + + @staticmethod + def from_encoded_array(obj_array, cipher=None, context=None): + """Create an array of Annotations from encoded objects""" + return [Annotation.from_encoded(obj, cipher, context) for obj in obj_array] + + @staticmethod + def from_values(values): + """Create an Annotation from a dict of values""" + return Annotation(**values) + + @staticmethod + def __update_empty_fields(proto_msg: dict, annotation: dict, annotation_index: int): + """Update empty annotation fields with values from protocol message""" + if annotation.get("id") is None or annotation.get("id") == '': + annotation['id'] = f"{proto_msg.get('id')}:{annotation_index}" + if annotation.get("connectionId") is None or annotation.get("connectionId") == '': + annotation['connectionId'] = proto_msg.get('connectionId') + if annotation.get("timestamp") is None or annotation.get("timestamp") == 0: + annotation['timestamp'] = proto_msg.get('timestamp') + + @staticmethod + def update_inner_annotation_fields(proto_msg: dict): + """ + Update inner annotation fields with protocol message data (RTAN4b). + + Populates empty id, connectionId, and timestamp fields in annotations + from the protocol message values. + """ + annotations: list[dict] = proto_msg.get('annotations') + if annotations is not None: + for annotation_index, annotation in enumerate(annotations): + Annotation.__update_empty_fields(proto_msg, annotation, annotation_index) + + def __str__(self): + return ( + f"Annotation(action={self.action}, messageSerial={self.message_serial}, " + f"type={self.type}, name={self.name})" + ) + + def __repr__(self): + return self.__str__() + + def _copy_with(self, + action=_UNSET, + serial=_UNSET, + message_serial=_UNSET, + type=_UNSET, + name=_UNSET, + count=_UNSET, + data=_UNSET, + encoding=_UNSET, + id=_UNSET, + client_id=_UNSET, + connection_id=_UNSET, + timestamp=_UNSET, + extras=_UNSET): + """ + Create a copy of this Annotation with optionally modified fields. + + To explicitly set a field to None, pass None as the value. + Fields not provided will retain their original values. + + Args: + action: Override the action type (or None to clear it) + serial: Override the serial (or None to clear it) + message_serial: Override the message serial (or None to clear it) + type: Override the type (or None to clear it) + name: Override the name (or None to clear it) + count: Override the count (or None to clear it) + data: Override the data payload (or None to clear it) + encoding: Override the encoding format (or None to clear it) + id: Override the ID (or None to clear it) + client_id: Override the client ID (or None to clear it) + connection_id: Override the connection ID (or None to clear it) + timestamp: Override the timestamp (or None to clear it) + extras: Override the extras metadata (or None to clear it) + + Returns: + A new Annotation instance with the specified fields updated + + Example: + # Keep existing name, change type + new_ann = annotation.copy_with(type="like") + + # Explicitly set name to None + new_ann = annotation.copy_with(name=None) + """ + # Get encoding from the mixin's property + return Annotation( + action=self.__action if action is _UNSET else action, + serial=self.__serial if serial is _UNSET else serial, + message_serial=self.__message_serial if message_serial is _UNSET else message_serial, + type=self.__type if type is _UNSET else type, + name=self.__name if name is _UNSET else name, + count=self.__count if count is _UNSET else count, + data=self.__data if data is _UNSET else data, + encoding=self.__encoding if encoding is _UNSET else encoding, + id=self.__id if id is _UNSET else id, + client_id=self.__client_id if client_id is _UNSET else client_id, + connection_id=self.__connection_id if connection_id is _UNSET else connection_id, + timestamp=self.__timestamp if timestamp is _UNSET else timestamp, + extras=self.__extras if extras is _UNSET else extras, + ) + + +def make_annotation_response_handler(cipher=None): + """Create a response handler for annotation API responses""" + def annotation_response_handler(response): + annotations = response.to_native() + return Annotation.from_encoded_array(annotations, cipher=cipher) + return annotation_response_handler + + +def make_single_annotation_response_handler(cipher=None): + """Create a response handler for single annotation API responses""" + def single_annotation_response_handler(response): + annotation = response.to_native() + return Annotation.from_encoded(annotation, cipher=cipher) + return single_annotation_response_handler diff --git a/ably/types/channelmode.py b/ably/types/channelmode.py new file mode 100644 index 00000000..23ed735c --- /dev/null +++ b/ably/types/channelmode.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +from enum import Enum + +from ably.types.flags import Flag + + +class ChannelMode(int, Enum): + PRESENCE = Flag.PRESENCE + PUBLISH = Flag.PUBLISH + SUBSCRIBE = Flag.SUBSCRIBE + PRESENCE_SUBSCRIBE = Flag.PRESENCE_SUBSCRIBE + ANNOTATION_PUBLISH = Flag.ANNOTATION_PUBLISH + ANNOTATION_SUBSCRIBE = Flag.ANNOTATION_SUBSCRIBE + + +def encode_channel_mode(modes: list[ChannelMode]) -> int: + """ + Encode a list of ChannelMode values into a bitmask. + + Args: + modes: List of ChannelMode values to encode + + Returns: + Integer bitmask with the corresponding flags set + """ + flags = 0 + + for mode in modes: + flags |= mode.value + + return flags + + +def decode_channel_mode(flags: int) -> list[ChannelMode]: + """ + Decode channel mode flags from a bitmask into a list of ChannelMode values. + + Args: + flags: Integer bitmask containing channel mode flags + + Returns: + List of ChannelMode values that are set in the flags + """ + modes = [] + + # Check each channel mode flag + for mode in ChannelMode: + if flags & mode.value: + modes.append(mode) + + return modes diff --git a/ably/types/channeloptions.py b/ably/types/channeloptions.py index 48e34dfe..3e5052c6 100644 --- a/ably/types/channeloptions.py +++ b/ably/types/channeloptions.py @@ -2,6 +2,7 @@ from typing import Any +from ably.types.channelmode import ChannelMode from ably.util.crypto import CipherParams from ably.util.exceptions import AblyException @@ -17,36 +18,48 @@ class ChannelOptions: Channel parameters that configure the behavior of the channel. """ - def __init__(self, cipher: CipherParams | None = None, params: dict | None = None): + def __init__( + self, + cipher: CipherParams | None = None, + params: dict | None = None, + modes: list[ChannelMode] | None = None + ): self.__cipher = cipher self.__params = params + self.__modes = modes # Validate params if self.__params and not isinstance(self.__params, dict): raise AblyException("params must be a dictionary", 40000, 400) @property - def cipher(self): + def cipher(self) -> CipherParams | None: """Get cipher configuration""" return self.__cipher @property - def params(self) -> dict[str, str]: + def params(self) -> dict[str, str] | None: """Get channel parameters""" return self.__params + @property + def modes(self) -> list[ChannelMode] | None: + """Get channel modes""" + return self.__modes + def __eq__(self, other): """Check equality with another ChannelOptions instance""" if not isinstance(other, ChannelOptions): return False return (self.__cipher == other.__cipher and - self.__params == other.__params) + self.__params == other.__params and self.__modes == other.__modes) def __hash__(self): """Make ChannelOptions hashable""" return hash(( self.__cipher, tuple(sorted(self.__params.items())) if self.__params else None, + tuple(sorted(self.__modes)) if self.__modes else None )) def to_dict(self) -> dict[str, Any]: @@ -56,6 +69,8 @@ def to_dict(self) -> dict[str, Any]: result['cipher'] = self.__cipher if self.__params: result['params'] = self.__params + if self.__modes: + result['modes'] = self.__modes return result @classmethod @@ -67,4 +82,5 @@ def from_dict(cls, options_dict: dict[str, Any]) -> ChannelOptions: return cls( cipher=options_dict.get('cipher'), params=options_dict.get('params'), + modes=options_dict.get('modes'), ) diff --git a/ably/types/flags.py b/ably/types/flags.py index 1666434c..86666019 100644 --- a/ably/types/flags.py +++ b/ably/types/flags.py @@ -13,6 +13,8 @@ class Flag(int, Enum): PUBLISH = 1 << 17 SUBSCRIBE = 1 << 18 PRESENCE_SUBSCRIBE = 1 << 19 + ANNOTATION_PUBLISH = 1 << 21 + ANNOTATION_SUBSCRIBE = 1 << 22 def has_flag(message_flags: int, flag: Flag): diff --git a/ably/types/message.py b/ably/types/message.py index 11caba57..2442a587 100644 --- a/ably/types/message.py +++ b/ably/types/message.py @@ -1,25 +1,50 @@ -import base64 -import json import logging from enum import IntEnum from ably.types.mixins import DeltaExtras, EncodeDataMixin from ably.types.typedbuffer import TypedBuffer from ably.util.crypto import CipherData +from ably.util.encoding import encode_data from ably.util.exceptions import AblyException +from ably.util.helper import to_text log = logging.getLogger(__name__) -def to_text(value): - if value is None: - return value - elif isinstance(value, str): - return value - elif isinstance(value, bytes): - return value.decode() - else: - raise TypeError(f"expected string or bytes, not {type(value)}") +class MessageAnnotations: + """ + Contains information about annotations associated with a particular message. + """ + + def __init__(self, summary=None): + """ + Args: + summary: A dict mapping annotation types to their aggregated values. + The keys are annotation types (e.g., "reaction:distinct.v1"). + The values depend on the aggregation method of the annotation type. + """ + # TM8a: Ensure summary exists + self.__summary = summary if summary is not None else {} + + @property + def summary(self): + """A dict of annotation type to aggregated annotation values.""" + return self.__summary + + def as_dict(self): + """Convert MessageAnnotations to dictionary format.""" + return { + 'summary': self.summary, + } + + @staticmethod + def from_dict(obj): + """Create MessageAnnotations from dictionary.""" + if obj is None: + return MessageAnnotations() + return MessageAnnotations( + summary=obj.get('summary'), + ) class MessageVersion: @@ -122,6 +147,7 @@ def __init__(self, serial=None, # TM2r action=None, # TM2j version=None, # TM2s + annotations=None, # TM2t ): super().__init__(encoding) @@ -137,6 +163,7 @@ def __init__(self, self.__serial = serial self.__action = action self.__version = version + self.__annotations = annotations def __eq__(self, other): if isinstance(other, Message): @@ -201,6 +228,10 @@ def serial(self): def action(self): return self.__action + @property + def annotations(self): + return self.__annotations + def encrypt(self, channel_cipher): if isinstance(self.data, CipherData): return @@ -234,38 +265,9 @@ def decrypt(self, channel_cipher): self.__data = decrypted_data def as_dict(self, binary=False): - data = self.data - data_type = None - encoding = self._encoding_array[:] - - if isinstance(data, (dict, list)): - encoding.append('json') - data = json.dumps(data) - data = str(data) - elif isinstance(data, str) and not binary: - pass - elif not binary and isinstance(data, (bytearray, bytes)): - data = base64.b64encode(data).decode('ascii') - encoding.append('base64') - elif isinstance(data, CipherData): - encoding.append(data.encoding_str) - data_type = data.type - if not binary: - data = base64.b64encode(data.buffer).decode('ascii') - encoding.append('base64') - else: - data = data.buffer - elif binary and isinstance(data, bytearray): - data = bytes(data) - - if not (isinstance(data, (bytes, str, list, dict, bytearray)) or data is None): - raise AblyException("Invalid data payload", 400, 40011) - request_body = { 'name': self.name, - 'data': data, 'timestamp': self.timestamp or None, - 'type': data_type or None, 'clientId': self.client_id or None, 'id': self.id or None, 'connectionId': self.connection_id or None, @@ -274,11 +276,10 @@ def as_dict(self, binary=False): 'version': self.version.as_dict() if self.version else None, 'serial': self.serial, 'action': int(self.action) if self.action is not None else None, + 'annotations': self.annotations.as_dict() if self.annotations else None, + **encode_data(self.data, self._encoding_array, binary), } - if encoding: - request_body['encoding'] = '/'.join(encoding).strip('/') - # None values aren't included request_body = {k: v for k, v in request_body.items() if v is not None} @@ -320,6 +321,31 @@ def from_encoded(obj, cipher=None, context=None): # TM2s version = MessageVersion(serial=serial, timestamp=timestamp) + # Parse annotations from the wire format + annotations_obj = obj.get('annotations') + if annotations_obj is None: + # TM2u: Always initialize annotations with empty summary + annotations = MessageAnnotations() + else: + annotations = MessageAnnotations.from_dict(annotations_obj) + + # Process annotation summary entries to ensure clipped fields are set + if annotations and annotations.summary: + for annotation_type, summary_entry in annotations.summary.items(): + # TM7c1c, TM7d1c: For distinct.v1, unique.v1, multiple.v1 + if (annotation_type.endswith(':distinct.v1') or + annotation_type.endswith(':unique.v1') or + annotation_type.endswith(':multiple.v1')): + # These types have entries that need clipped field + if isinstance(summary_entry, dict): + for _entry_key, entry_value in summary_entry.items(): + if isinstance(entry_value, dict) and 'clipped' not in entry_value: + entry_value['clipped'] = False + # TM7c1c: For flag.v1 + elif annotation_type.endswith(':flag.v1'): + if isinstance(summary_entry, dict) and 'clipped' not in summary_entry: + summary_entry['clipped'] = False + return Message( id=id, name=name, @@ -330,6 +356,7 @@ def from_encoded(obj, cipher=None, context=None): serial=serial, action=action, version=version, + annotations=annotations, **decoded_data ) diff --git a/ably/types/presence.py b/ably/types/presence.py index 723ceacc..7d1a3c05 100644 --- a/ably/types/presence.py +++ b/ably/types/presence.py @@ -1,5 +1,3 @@ -import base64 -import json from datetime import datetime, timedelta from urllib import parse @@ -7,7 +5,7 @@ from ably.types.mixins import EncodeDataMixin from ably.types.typedbuffer import TypedBuffer from ably.util.crypto import CipherData -from ably.util.exceptions import AblyException +from ably.util.encoding import encode_data def _ms_since_epoch(dt): @@ -151,36 +149,10 @@ def to_encoded(self, binary=False): Handles proper encoding of data including JSON serialization, base64 encoding for binary data, and encryption support. """ - data = self.data - data_type = None - encoding = self._encoding_array[:] - - # Handle different data types and build encoding string - if isinstance(data, (dict, list)): - encoding.append('json') - data = json.dumps(data) - data = str(data) - elif isinstance(data, str) and not binary: - pass - elif not binary and isinstance(data, (bytearray, bytes)): - data = base64.b64encode(data).decode('ascii') - encoding.append('base64') - elif isinstance(data, CipherData): - encoding.append(data.encoding_str) - data_type = data.type - if not binary: - data = base64.b64encode(data.buffer).decode('ascii') - encoding.append('base64') - else: - data = data.buffer - elif binary and isinstance(data, bytearray): - data = bytes(data) - - if not (isinstance(data, (bytes, str, list, dict, bytearray)) or data is None): - raise AblyException("Invalid data payload", 400, 40011) result = { 'action': self.action, + **encode_data(self.data, self._encoding_array, binary), } if self.id: @@ -189,12 +161,6 @@ def to_encoded(self, binary=False): result['clientId'] = self.client_id if self.connection_id: result['connectionId'] = self.connection_id - if data is not None: - result['data'] = data - if data_type: - result['type'] = data_type - if encoding: - result['encoding'] = '/'.join(encoding).strip('/') if self.extras: result['extras'] = self.extras if self.timestamp: diff --git a/ably/util/encoding.py b/ably/util/encoding.py new file mode 100644 index 00000000..5187aec2 --- /dev/null +++ b/ably/util/encoding.py @@ -0,0 +1,38 @@ +import base64 +import json +from typing import Any + +from ably.util.crypto import CipherData +from ably.util.exceptions import AblyException + + +def encode_data(data: Any, encoding_array: list, binary: bool = False): + encoding = encoding_array[:] + + if isinstance(data, (dict, list)): + encoding.append('json') + data = json.dumps(data) # json.dumps already returns str + elif isinstance(data, str) and not binary: + pass + elif not binary and isinstance(data, (bytearray, bytes)): + data = base64.b64encode(data).decode('ascii') + encoding.append('base64') + elif isinstance(data, CipherData): + encoding.append(data.encoding_str) + if not binary: + data = base64.b64encode(data.buffer).decode('ascii') + encoding.append('base64') + else: + data = data.buffer + elif binary and isinstance(data, bytearray): + data = bytes(data) + + result = { 'data': data } + + if not (isinstance(data, (bytes, str, list, dict, bytearray)) or data is None): + raise AblyException("Invalid data payload", 400, 40011) + + if encoding: + result['encoding'] = '/'.join(encoding).strip('/') + + return result diff --git a/ably/util/helper.py b/ably/util/helper.py index 53226f27..a35ebe6e 100644 --- a/ably/util/helper.py +++ b/ably/util/helper.py @@ -98,3 +98,13 @@ def validate_message_size(encoded_messages: list, use_binary_protocol: bool, max 400, 40009, ) + +def to_text(value): + if value is None: + return value + elif isinstance(value, str): + return value + elif isinstance(value, bytes): + return value.decode() + else: + raise TypeError(f"expected string or bytes, not {type(value)}") diff --git a/test/ably/realtime/realtimeannotations_test.py b/test/ably/realtime/realtimeannotations_test.py new file mode 100644 index 00000000..a82b6b2b --- /dev/null +++ b/test/ably/realtime/realtimeannotations_test.py @@ -0,0 +1,343 @@ +import asyncio +import logging +import random +import string + +import pytest + +from ably.types.annotation import Annotation, AnnotationAction +from ably.types.channelmode import ChannelMode +from ably.types.channeloptions import ChannelOptions +from ably.types.message import MessageAction +from test.ably.testapp import TestApp +from test.ably.utils import BaseAsyncTestCase, ReusableFuture, assert_waiter + +log = logging.getLogger(__name__) + + +@pytest.mark.parametrize("transport", ["json", "msgpack"], ids=["JSON", "MsgPack"]) +class TestRealtimeAnnotations(BaseAsyncTestCase): + + @pytest.fixture(autouse=True) + async def setup(self, transport): + self.test_vars = await TestApp.get_test_vars() + + client_id = ''.join(random.choices(string.ascii_letters + string.digits, k=10)) + self.realtime_client = await TestApp.get_ably_realtime( + use_binary_protocol=True if transport == 'msgpack' else False, + client_id=client_id, + ) + self.rest_client = await TestApp.get_ably_rest( + use_binary_protocol=True if transport == 'msgpack' else False, + client_id=client_id, + ) + + async def test_publish_and_subscribe_annotations(self): + """RTAN1/RTAN4: Publish and subscribe to annotations via realtime and REST""" + channel_options = ChannelOptions(modes=[ + ChannelMode.PUBLISH, + ChannelMode.SUBSCRIBE, + ChannelMode.ANNOTATION_PUBLISH, + ChannelMode.ANNOTATION_SUBSCRIBE + ]) + channel_name = self.get_channel_name('mutable:publish_and_subscribe_annotations') + channel = self.realtime_client.channels.get( + channel_name, + channel_options, + ) + rest_channel = self.rest_client.channels.get(channel_name) + await channel.attach() + + # Setup annotation listener + annotation_future = asyncio.Future() + + async def on_annotation(annotation): + if not annotation_future.done(): + annotation_future.set_result(annotation) + + await channel.annotations.subscribe(on_annotation) + + # Publish a message + publish_result = await channel.publish('message', 'foobar') + + # Reset for next message (summary) + message_summary = asyncio.Future() + + def on_message(msg): + if not message_summary.done(): + message_summary.set_result(msg) + + await channel.subscribe('message', on_message) + + # Publish annotation using realtime + await channel.annotations.publish(publish_result.serials[0], Annotation( + type='reaction:distinct.v1', + name='👍' + )) + + # Wait for annotation + annotation = await annotation_future + assert annotation.action == AnnotationAction.ANNOTATION_CREATE + assert annotation.message_serial == publish_result.serials[0] + assert annotation.type == 'reaction:distinct.v1' + assert annotation.name == '👍' + assert annotation.serial > annotation.message_serial + + # Wait for summary message + summary = await message_summary + assert summary.action == MessageAction.MESSAGE_SUMMARY + assert summary.serial == publish_result.serials[0] + assert summary.annotations.summary['reaction:distinct.v1']['👍']['total'] == 1 + + # Try again but with REST publish + annotation_future2 = asyncio.Future() + + async def on_annotation2(annotation): + if not annotation_future2.done(): + annotation_future2.set_result(annotation) + + await channel.annotations.subscribe(on_annotation2) + + await rest_channel.annotations.publish(publish_result.serials[0], Annotation( + type='reaction:distinct.v1', + name='😕' + )) + + annotation = await annotation_future2 + assert annotation.action == AnnotationAction.ANNOTATION_CREATE + assert annotation.message_serial == publish_result.serials[0] + assert annotation.type == 'reaction:distinct.v1' + assert annotation.name == '😕' + assert annotation.serial > annotation.message_serial + + async def test_get_all_annotations_for_a_message(self): + """RTAN3: Retrieve all annotations for a message""" + channel_options = ChannelOptions(modes=[ + ChannelMode.PUBLISH, + ChannelMode.SUBSCRIBE, + ChannelMode.ANNOTATION_PUBLISH, + ChannelMode.ANNOTATION_SUBSCRIBE + ]) + channel = self.realtime_client.channels.get( + self.get_channel_name('mutable:get_all_annotations_for_a_message'), + channel_options + ) + await channel.attach() + + # Publish a message + publish_result = await channel.publish('message', 'foobar') + + # Publish multiple annotations + emojis = ['👍', '😕', '👎'] + for emoji in emojis: + await channel.annotations.publish(publish_result.serials[0], Annotation( + type='reaction:distinct.v1', + name=emoji + )) + + # Wait for all annotations to appear + annotations = [] + + async def check_annotations(): + nonlocal annotations + res = await channel.annotations.get(publish_result.serials[0], {}) + annotations = res.items + return len(annotations) == 3 + + await assert_waiter(check_annotations, timeout=10) + + # Verify annotations + assert annotations[0].action == AnnotationAction.ANNOTATION_CREATE + assert annotations[0].message_serial == publish_result.serials[0] + assert annotations[0].type == 'reaction:distinct.v1' + assert annotations[0].name == '👍' + assert annotations[1].name == '😕' + assert annotations[2].name == '👎' + assert annotations[1].serial > annotations[0].serial + assert annotations[2].serial > annotations[1].serial + + async def test_subscribe_by_annotation_type(self): + """RTAN4c: Subscribe to annotations filtered by type""" + channel_options = ChannelOptions(modes=[ + ChannelMode.PUBLISH, + ChannelMode.SUBSCRIBE, + ChannelMode.ANNOTATION_PUBLISH, + ChannelMode.ANNOTATION_SUBSCRIBE + ]) + channel = self.realtime_client.channels.get( + self.get_channel_name('mutable:subscribe_by_type'), + channel_options + ) + await channel.attach() + + # Setup message listener + message_future = asyncio.Future() + + def on_message(msg): + if not message_future.done(): + message_future.set_result(msg) + + await channel.subscribe('message', on_message) + + # Subscribe to specific annotation type + reaction_future = asyncio.Future() + + async def on_reaction(annotation): + if not reaction_future.done(): + reaction_future.set_result(annotation) + + await channel.annotations.subscribe('reaction:distinct.v1', on_reaction) + + # Publish message and annotation + publish_result = await channel.publish('message', 'test') + + await channel.annotations.publish(publish_result.serials[0], Annotation( + type='reaction:distinct.v1', + name='👍' + )) + + # Should receive the annotation + annotation = await reaction_future + assert annotation.type == 'reaction:distinct.v1' + assert annotation.name == '👍' + + async def test_unsubscribe_annotations(self): + """RTAN5: Unsubscribe from annotation events""" + channel_options = ChannelOptions(modes=[ + ChannelMode.PUBLISH, + ChannelMode.SUBSCRIBE, + ChannelMode.ANNOTATION_PUBLISH, + ChannelMode.ANNOTATION_SUBSCRIBE + ]) + channel = self.realtime_client.channels.get( + self.get_channel_name('mutable:unsubscribe_annotations'), + channel_options + ) + await channel.attach() + + annotations_received = [] + annotation_future = ReusableFuture() + + async def on_annotation(annotation): + annotations_received.append(annotation) + annotation_future.set_result(annotation) + + await channel.annotations.subscribe(on_annotation) + + # Publish message and first annotation + publish_result = await channel.publish('message', 'test') + + await channel.annotations.publish(publish_result.serials[0], Annotation( + type='reaction:distinct.v1', + name='👍' + )) + + # Wait for the first annotation to appear + await annotation_future.get() + assert len(annotations_received) == 1 + + # Unsubscribe + channel.annotations.unsubscribe(on_annotation) + + await channel.annotations.subscribe(lambda annotation: annotation_future.set_result(annotation)) + + # Publish another annotation + await channel.annotations.publish(publish_result.serials[0], Annotation( + type='reaction:distinct.v1', + name='😕' + )) + + # Wait for the second annotation to appear in another listener + await annotation_future.get() + + assert len(annotations_received) == 1 + + async def test_delete_annotation(self): + """RTAN2: Delete an annotation via realtime""" + channel_options = ChannelOptions(modes=[ + ChannelMode.PUBLISH, + ChannelMode.SUBSCRIBE, + ChannelMode.ANNOTATION_PUBLISH, + ChannelMode.ANNOTATION_SUBSCRIBE + ]) + channel = self.realtime_client.channels.get( + self.get_channel_name('mutable:delete_annotation'), + channel_options + ) + await channel.attach() + + # Setup message listener + message_future = asyncio.Future() + + def on_message(msg): + if not message_future.done(): + message_future.set_result(msg) + + await channel.subscribe('message', on_message) + + annotations_received = [] + annotation_future = ReusableFuture() + async def on_annotation(annotation): + annotations_received.append(annotation) + annotation_future.set_result(annotation) + + await channel.annotations.subscribe(on_annotation) + + # Publish message and annotation + await channel.publish('message', 'test') + message = await message_future + + await channel.annotations.publish(message.serial, Annotation( + type='reaction:distinct.v1', + name='👍' + )) + + await annotation_future.get() + + # Wait for create annotation + assert len(annotations_received) == 1 + assert annotations_received[0].action == AnnotationAction.ANNOTATION_CREATE + + # Delete the annotation + await channel.annotations.delete(message.serial, Annotation( + type='reaction:distinct.v1', + name='👍' + )) + + # Wait for delete annotation + await annotation_future.get() + + assert len(annotations_received) == 2 + assert annotations_received[1].action == AnnotationAction.ANNOTATION_DELETE + + async def test_subscribe_without_annotation_mode_warns(self, caplog): + """RTAN4e: Subscribing without ANNOTATION_SUBSCRIBE mode logs a warning. + + Per spec, the library should log a warning indicating that the user has tried + to add an annotation listener without having requested the ANNOTATION_SUBSCRIBE + channel mode. + """ + # Create channel without annotation_subscribe mode + channel_options = ChannelOptions(modes=[ + ChannelMode.PUBLISH, + ChannelMode.SUBSCRIBE + ]) + channel = self.realtime_client.channels.get( + self.get_channel_name('mutable:no_annotation_mode'), + channel_options + ) + await channel.attach() + + async def on_annotation(annotation): + pass + + # RTAN4e: Should log a warning (not raise), and still register the listener + with caplog.at_level(logging.WARNING, logger='ably.realtime.annotations'): + await channel.annotations.subscribe(on_annotation) + + # Verify warning was logged mentioning the missing mode + assert any('ANNOTATION_SUBSCRIBE' in record.message for record in caplog.records) + + # Listener should still be registered (subscribe didn't fail) + # Unsubscribe to clean up + channel.annotations.unsubscribe(on_annotation) diff --git a/test/ably/realtime/realtimeconnection_test.py b/test/ably/realtime/realtimeconnection_test.py index b38c5aaf..f1eb9003 100644 --- a/test/ably/realtime/realtimeconnection_test.py +++ b/test/ably/realtime/realtimeconnection_test.py @@ -369,7 +369,7 @@ async def test_connection_client_id_query_params(self): ably = await TestApp.get_ably_realtime(client_id=client_id) await asyncio.wait_for(ably.connection.once_async(ConnectionState.CONNECTED), timeout=5) - assert ably.connection.connection_manager.transport.params["client_id"] == client_id + assert ably.connection.connection_manager.transport.params["clientId"] == client_id assert ably.auth.client_id == client_id await ably.close() diff --git a/test/ably/rest/restannotations_test.py b/test/ably/rest/restannotations_test.py new file mode 100644 index 00000000..fcf2c696 --- /dev/null +++ b/test/ably/rest/restannotations_test.py @@ -0,0 +1,203 @@ +import logging +import random +import string + +import pytest + +from ably import AblyException +from ably.types.annotation import Annotation, AnnotationAction +from ably.types.message import Message +from test.ably.testapp import TestApp +from test.ably.utils import BaseAsyncTestCase, assert_waiter + +log = logging.getLogger(__name__) + + +@pytest.mark.parametrize("transport", ["json", "msgpack"], ids=["JSON", "MsgPack"]) +class TestRestAnnotations(BaseAsyncTestCase): + + @pytest.fixture(autouse=True) + async def setup(self, transport): + self.test_vars = await TestApp.get_test_vars() + client_id = ''.join(random.choices(string.ascii_letters + string.digits, k=10)) + self.ably = await TestApp.get_ably_rest( + use_binary_protocol=True if transport == 'msgpack' else False, + client_id=client_id, + ) + + async def test_publish_annotation_success(self): + """Test successfully publishing an annotation on a message""" + channel = self.ably.channels[self.get_channel_name('mutable:annotation_publish_test')] + + # First publish a message + result = await channel.publish('test-event', 'test data') + assert result.serials is not None + assert len(result.serials) > 0 + serial = result.serials[0] + + # Publish an annotation + await channel.annotations.publish(serial, Annotation( + type='reaction:distinct.v1', + name='👍' + )) + + annotations_result = None + + # Wait for annotations to appear + async def check_annotations(): + nonlocal annotations_result + annotations_result = await channel.annotations.get(serial) + return len(annotations_result.items) == 1 + + await assert_waiter(check_annotations, timeout=10) + + # Get annotations to verify + annotations = annotations_result.items + assert len(annotations) >= 1 + assert annotations[0].message_serial == serial + assert annotations[0].type == 'reaction:distinct.v1' + assert annotations[0].name == '👍' + + async def test_publish_annotation_with_message_object(self): + """Test publishing an annotation using a Message object""" + channel = self.ably.channels[self.get_channel_name('mutable:annotation_publish_msg_obj')] + + # Publish a message + result = await channel.publish('test-event', 'test data') + serial = result.serials[0] + + # Create a message object + message = Message(serial=serial) + + # Publish annotation with message object + await channel.annotations.publish(message, Annotation( + type='reaction:distinct.v1', + name='😕' + )) + + annotations_result = None + + # Wait for annotations to appear + async def check_annotations(): + nonlocal annotations_result + annotations_result = await channel.annotations.get(serial) + return len(annotations_result.items) == 1 + + await assert_waiter(check_annotations, timeout=10) + + # Verify + annotations_result = await channel.annotations.get(serial) + annotations = annotations_result.items + assert len(annotations) >= 1 + assert annotations[0].name == '😕' + + async def test_publish_annotation_without_serial_fails(self): + """Test that publishing without a serial raises an exception""" + channel = self.ably.channels[self.get_channel_name('mutable:annotation_no_serial')] + + with pytest.raises(AblyException) as exc_info: + await channel.annotations.publish(None, Annotation(type='reaction', name='👍')) + + assert exc_info.value.status_code == 400 + assert exc_info.value.code == 40003 + + async def test_delete_annotation_success(self): + """Test successfully deleting an annotation""" + channel = self.ably.channels[self.get_channel_name('mutable:annotation_delete_test')] + + # Publish a message + result = await channel.publish('test-event', 'test data') + serial = result.serials[0] + + # Publish an annotation + await channel.annotations.publish(serial, Annotation( + type='reaction:distinct.v1', + name='👍' + )) + + annotations_result = None + + # Wait for annotation to appear + async def check_annotation(): + nonlocal annotations_result + annotations_result = await channel.annotations.get(serial) + return len(annotations_result.items) >= 1 + + await assert_waiter(check_annotation, timeout=10) + + # Delete the annotation + await channel.annotations.delete(serial, Annotation( + type='reaction:distinct.v1', + name='👍' + )) + + # Wait for annotation to appear + async def check_deleted_annotation(): + nonlocal annotations_result + annotations_result = await channel.annotations.get(serial) + return len(annotations_result.items) >= 2 + + await assert_waiter(check_deleted_annotation, timeout=10) + assert annotations_result.items[-1].type == 'reaction:distinct.v1' + assert annotations_result.items[-1].action == AnnotationAction.ANNOTATION_DELETE + + async def test_get_all_annotations(self): + """Test retrieving all annotations for a message""" + channel = self.ably.channels[self.get_channel_name('mutable:annotation_get_all_test')] + + # Publish a message + result = await channel.publish('test-event', 'test data') + serial = result.serials[0] + + # Publish annotations + await channel.annotations.publish(serial, Annotation(type='reaction:distinct.v1', name='👍')) + await channel.annotations.publish(serial, Annotation(type='reaction:distinct.v1', name='😕')) + await channel.annotations.publish(serial, Annotation(type='reaction:distinct.v1', name='👎')) + + # Wait and get all annotations + async def check_annotations(): + res = await channel.annotations.get(serial) + return len(res.items) >= 3 + + await assert_waiter(check_annotations, timeout=10) + + annotations_result = await channel.annotations.get(serial) + annotations = annotations_result.items + assert len(annotations) >= 3 + assert annotations[0].type == 'reaction:distinct.v1' + assert annotations[0].message_serial == serial + # Verify serials are in order + if len(annotations) > 1: + assert annotations[1].serial > annotations[0].serial + if len(annotations) > 2: + assert annotations[2].serial > annotations[1].serial + + async def test_annotation_properties(self): + """Test that annotation properties are correctly set""" + channel = self.ably.channels[self.get_channel_name('mutable:annotation_properties_test')] + + # Publish a message + result = await channel.publish('test-event', 'test data') + serial = result.serials[0] + + # Publish annotation with various properties + await channel.annotations.publish(serial, Annotation( + type='reaction:distinct.v1', + name='❤️', + data={'count': 5} + )) + + # Retrieve and verify + async def check_annotation(): + res = await channel.annotations.get(serial) + return len(res.items) > 0 + + await assert_waiter(check_annotation, timeout=10) + + annotations_result = await channel.annotations.get(serial) + annotation = annotations_result.items[0] + assert annotation.message_serial == serial + assert annotation.type == 'reaction:distinct.v1' + assert annotation.name == '❤️' + assert annotation.serial is not None + assert annotation.serial > serial diff --git a/test/ably/utils.py b/test/ably/utils.py index 09658fc0..ae19e0b5 100644 --- a/test/ably/utils.py +++ b/test/ably/utils.py @@ -229,6 +229,9 @@ def assert_waiter_sync(block: Callable[[], bool], timeout: float = 10) -> None: class WaitableEvent: + """ + Replacement for asyncio.Future that will work with autogenerated sync tests. + """ def __init__(self): self._finished = False @@ -243,3 +246,22 @@ async def wait(self, timeout=10): def finish(self): self._finished = True + +class ReusableFuture: + """ + A reusable future that after each wait() resets itself and wait for the next value. + """ + def __init__(self): + self.__future = asyncio.Future() + + async def get(self, timeout=10): + await asyncio.wait_for(self.__future, timeout=timeout) + self.__future = asyncio.Future() + + def set_result(self, result): + if not self.__future.done(): + self.__future.set_result(result) + + def set_exception(self, exception): + if not self.__future.done(): + self.__future.set_exception(exception) diff --git a/test/unit/annotation_test.py b/test/unit/annotation_test.py new file mode 100644 index 00000000..947ed04e --- /dev/null +++ b/test/unit/annotation_test.py @@ -0,0 +1,319 @@ +"""Unit tests for Annotation type and validation logic. + +Tests cover: +- RSAN1a3: type validation in construct_validate_annotation +- TAN2a: id and connectionId fields on Annotation +- RSAN1c4: idempotent publishing ID format +- RTAN4b: protocol message field population +- RSAN1c1/RSAN2a: explicit action setting in publish/delete +- TAN3: from_encoded / from_encoded_array decoding +- TAN2i: serial-based equality +""" + +import base64 + +import pytest + +from ably.rest.annotations import construct_validate_annotation, serial_from_msg_or_serial +from ably.types.annotation import Annotation, AnnotationAction +from ably.types.message import Message +from ably.util.exceptions import AblyException + +# --- RSAN1a3: type validation --- + +def test_construct_validate_annotation_requires_type(): + """RSAN1a3: Annotation type must be specified""" + annotation = Annotation(name='👍') # No type + with pytest.raises(AblyException) as exc_info: + construct_validate_annotation('serial123', annotation) + assert exc_info.value.status_code == 400 + assert exc_info.value.code == 40000 + assert 'type' in str(exc_info.value).lower() + + +def test_construct_validate_annotation_with_type_succeeds(): + """RSAN1a3: Annotation with type should pass validation""" + annotation = Annotation(type='reaction:distinct.v1', name='👍') + result = construct_validate_annotation('serial123', annotation) + assert result.type == 'reaction:distinct.v1' + assert result.message_serial == 'serial123' + + +def test_construct_validate_annotation_requires_annotation_object(): + """Second argument must be an Annotation instance""" + with pytest.raises(AblyException) as exc_info: + construct_validate_annotation('serial123', 'not_an_annotation') + assert exc_info.value.status_code == 400 + + +def test_serial_from_msg_or_serial_with_string(): + """RSAN1a: Accept string serial""" + assert serial_from_msg_or_serial('abc123') == 'abc123' + + +def test_serial_from_msg_or_serial_with_message(): + """RSAN1a1: Accept Message object with serial""" + msg = Message(serial='abc123') + assert serial_from_msg_or_serial(msg) == 'abc123' + + +def test_serial_from_msg_or_serial_rejects_invalid(): + """RSAN1a: Reject invalid input""" + with pytest.raises(AblyException): + serial_from_msg_or_serial(None) + with pytest.raises(AblyException): + serial_from_msg_or_serial(12345) + + +# --- TAN2a: id field on Annotation --- + +def test_annotation_has_id_field(): + """TAN2a: Annotation must have id field""" + annotation = Annotation(id='test-id-123', type='reaction', name='👍') + assert annotation.id == 'test-id-123' + + +def test_annotation_id_in_as_dict(): + """TAN2a: id should be included in as_dict() output""" + annotation = Annotation(id='test-id', type='reaction', name='👍') + d = annotation.as_dict() + assert d['id'] == 'test-id' + + +def test_annotation_id_from_encoded(): + """TAN2a: id should be read from encoded wire format""" + encoded = { + 'id': 'wire-id-123', + 'type': 'reaction', + 'name': '👍', + 'action': 0, + } + annotation = Annotation.from_encoded(encoded) + assert annotation.id == 'wire-id-123' + + +def test_annotation_id_in_copy_with(): + """TAN2a: id should be preserved/overridden in _copy_with()""" + annotation = Annotation(id='original-id', type='reaction', name='👍') + copy = annotation._copy_with(id='new-id') + assert copy.id == 'new-id' + assert annotation.id == 'original-id' # Original unchanged + + +# --- TAN2a/TAN2c: connectionId field --- + +def test_annotation_has_connection_id(): + """Annotation must have connection_id field""" + annotation = Annotation(connection_id='conn-123', type='reaction', name='👍') + assert annotation.connection_id == 'conn-123' + + +def test_annotation_connection_id_from_encoded(): + """connection_id should be read from encoded wire format""" + encoded = { + 'connectionId': 'conn-456', + 'type': 'reaction', + 'action': 0, + } + annotation = Annotation.from_encoded(encoded) + assert annotation.connection_id == 'conn-456' + + +# --- RSAN1c4: idempotent publishing ID format --- + +def test_idempotent_id_format(): + """RSAN1c4: ID should be base64(9 random bytes) + ':0'""" + # We can't test the actual REST publish without a server, but we can + # verify the format by checking the regex pattern + import os + random_id = base64.b64encode(os.urandom(9)).decode('ascii') + ':0' + # Should be base64 chars followed by ':0' + assert random_id.endswith(':0') + # Base64 of 9 bytes = 12 chars + base64_part = random_id[:-2] + assert len(base64_part) == 12 + # Verify it's valid base64 + decoded = base64.b64decode(base64_part) + assert len(decoded) == 9 + + +# --- RTAN4b: protocol message field population --- + +def test_update_inner_annotation_fields(): + """RTAN4b: Populate annotation fields from protocol message envelope""" + proto_msg = { + 'id': 'proto-msg-id', + 'connectionId': 'conn-abc', + 'timestamp': 1234567890, + 'annotations': [ + {'type': 'reaction', 'name': '👍'}, + {'type': 'reaction', 'name': '👎'}, + ] + } + Annotation.update_inner_annotation_fields(proto_msg) + annotations = proto_msg['annotations'] + + # First annotation + assert annotations[0]['id'] == 'proto-msg-id:0' + assert annotations[0]['connectionId'] == 'conn-abc' + assert annotations[0]['timestamp'] == 1234567890 + + # Second annotation + assert annotations[1]['id'] == 'proto-msg-id:1' + assert annotations[1]['connectionId'] == 'conn-abc' + assert annotations[1]['timestamp'] == 1234567890 + + +def test_update_inner_annotation_fields_preserves_existing(): + """RTAN4b: Don't overwrite existing annotation fields""" + proto_msg = { + 'id': 'proto-msg-id', + 'connectionId': 'conn-abc', + 'timestamp': 1234567890, + 'annotations': [ + { + 'type': 'reaction', + 'id': 'existing-id', + 'connectionId': 'existing-conn', + 'timestamp': 9999999999, + }, + ] + } + Annotation.update_inner_annotation_fields(proto_msg) + annotation = proto_msg['annotations'][0] + + # Existing values should be preserved + assert annotation['id'] == 'existing-id' + assert annotation['connectionId'] == 'existing-conn' + assert annotation['timestamp'] == 9999999999 + + +def test_update_inner_annotation_fields_no_annotations(): + """RTAN4b: Should handle missing annotations gracefully""" + proto_msg = {'id': 'proto-msg-id'} + # Should not raise + Annotation.update_inner_annotation_fields(proto_msg) + + +# --- RSAN1c1/RSAN2a: explicit action setting --- + +def test_annotation_default_action_is_create(): + """Default action should be ANNOTATION_CREATE""" + annotation = Annotation(type='reaction', name='👍') + assert annotation.action == AnnotationAction.ANNOTATION_CREATE + + +def test_annotation_copy_with_action(): + """_copy_with should allow changing action""" + annotation = Annotation(type='reaction', name='👍') + deleted = annotation._copy_with(action=AnnotationAction.ANNOTATION_DELETE) + assert deleted.action == AnnotationAction.ANNOTATION_DELETE + assert annotation.action == AnnotationAction.ANNOTATION_CREATE # Original unchanged + + +# --- TAN3: from_encoded() with None data --- + +def test_from_encoded_with_none_data(): + """from_encoded should handle None data properly""" + encoded = { + 'type': 'reaction', + 'name': '👍', + 'action': 0, + } + annotation = Annotation.from_encoded(encoded) + assert annotation.data is None + assert annotation.type == 'reaction' + + +def test_from_encoded_with_data(): + """from_encoded should decode data when present""" + encoded = { + 'type': 'reaction', + 'name': '👍', + 'action': 0, + 'data': 'hello', + } + annotation = Annotation.from_encoded(encoded) + assert annotation.data == 'hello' + + +def test_from_encoded_with_json_data(): + """from_encoded should decode JSON-encoded data""" + import json + encoded = { + 'type': 'reaction', + 'action': 0, + 'data': json.dumps({'count': 5}), + 'encoding': 'json', + } + annotation = Annotation.from_encoded(encoded) + assert annotation.data == {'count': 5} + + +# --- TAN2i: __eq__ based on serial --- + +def test_annotation_eq_by_serial(): + """TAN2i: Annotations with same serial should be equal""" + a1 = Annotation(serial='s1', type='reaction', name='👍') + a2 = Annotation(serial='s1', type='different', name='👎') + assert a1 == a2 + + +def test_annotation_ne_by_serial(): + """TAN2i: Annotations with different serials should not be equal""" + a1 = Annotation(serial='s1', type='reaction', name='👍') + a2 = Annotation(serial='s2', type='reaction', name='👍') + assert a1 != a2 + + +def test_annotation_eq_fallback_includes_client_id(): + """Fallback equality should include client_id""" + a1 = Annotation(type='reaction', name='👍', client_id='user1', + message_serial='ms1', action=AnnotationAction.ANNOTATION_CREATE) + a2 = Annotation(type='reaction', name='👍', client_id='user2', + message_serial='ms1', action=AnnotationAction.ANNOTATION_CREATE) + assert a1 != a2 # Different client_id + + +def test_annotation_eq_fallback_same_fields(): + """Fallback equality with same fields should be equal""" + a1 = Annotation(type='reaction', name='👍', client_id='user1', + message_serial='ms1', action=AnnotationAction.ANNOTATION_CREATE) + a2 = Annotation(type='reaction', name='👍', client_id='user1', + message_serial='ms1', action=AnnotationAction.ANNOTATION_CREATE) + assert a1 == a2 + + +# --- as_dict serialization --- + +def test_annotation_as_dict_filters_none(): + """as_dict should not include None values""" + annotation = Annotation(type='reaction', name='👍') + d = annotation.as_dict() + assert 'serial' not in d + assert 'extras' not in d + assert 'type' in d + assert 'name' in d + + +def test_annotation_as_dict_includes_action(): + """as_dict should include action as integer""" + annotation = Annotation(type='reaction', name='👍', action=AnnotationAction.ANNOTATION_DELETE) + d = annotation.as_dict() + assert d['action'] == 1 # ANNOTATION_DELETE + + +# --- from_encoded_array --- + +def test_from_encoded_array(): + """from_encoded_array should decode multiple annotations""" + encoded_array = [ + {'type': 'reaction', 'name': '👍', 'action': 0}, + {'type': 'reaction', 'name': '👎', 'action': 1}, + ] + annotations = Annotation.from_encoded_array(encoded_array) + assert len(annotations) == 2 + assert annotations[0].name == '👍' + assert annotations[0].action == AnnotationAction.ANNOTATION_CREATE + assert annotations[1].name == '👎' + assert annotations[1].action == AnnotationAction.ANNOTATION_DELETE diff --git a/uv.lock b/uv.lock index 1b196ab7..5b48323d 100644 --- a/uv.lock +++ b/uv.lock @@ -10,7 +10,7 @@ resolution-markers = [ [[package]] name = "ably" -version = "2.1.3" +version = "3.0.0" source = { editable = "." } dependencies = [ { name = "h2", version = "4.1.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" },