diff --git a/actions/DiscordCore.py b/actions/DiscordCore.py index 12e87e0..444ddd7 100644 --- a/actions/DiscordCore.py +++ b/actions/DiscordCore.py @@ -45,14 +45,12 @@ def create_event_assigners(self): def register_backend_callback(self, key: str, callback: callable): """Register a callback and track it for cleanup.""" - self.backend.register_callback(key, callback) self._registered_callbacks.append((key, callback)) self.plugin_base.add_callback(key, callback) def cleanup_callbacks(self): """Unregister all tracked callbacks to prevent memory leaks.""" for key, callback in self._registered_callbacks: - self.backend.unregister_callback(key, callback) self.plugin_base.remove_callback(key, callback) self._registered_callbacks.clear() diff --git a/backend.py b/backend.py index 5bf6241..2b16552 100644 --- a/backend.py +++ b/backend.py @@ -1,4 +1,5 @@ import json +import threading from streamcontroller_plugin_tools import BackendBase @@ -16,24 +17,118 @@ def __init__(self): self.refresh_token: str = None self.discord_client: AsyncDiscord = None self._is_authed: bool = False + self._current_user_id: str = None self._current_voice_channel: str = None - self._is_reconnecting: bool = False self._voice_channel_users: dict = {} # {user_id: {username, nick, volume, muted}} - self._current_user_id: str = None # Current user's ID (for filtering) + self._connecting: bool = False + self._ready: bool = False + self._setup_lock = threading.Lock() + + def _ensure_ready(self): + """ + Single initialization gate. Delegates to setup_client for thread safety. + _ready means socket connected and polling started, not authenticated. + """ + if self._ready and self.discord_client: + return + + if not self.client_id or not self.client_secret: + return + + # Let setup_client handle locking - don't check _connecting here + self.setup_client() + + def setup_client(self): + with self._setup_lock: + # Combined check inside lock to prevent races + if self._connecting or (self._ready and self.discord_client): + return + + self._connecting = True + + try: + # Cleanup old client before creating new one to prevent thread leaks + if self.discord_client: + try: + self.discord_client.disconnect() + except Exception: + pass + + self.discord_client = AsyncDiscord( + self.client_id, + self.client_secret + ) + + self.discord_client.connect(self.discord_callback) + + if self.access_token: + self.discord_client.authenticate(self.access_token) + else: + self.discord_client.authorize() + + # _ready = socket connected & polling. Auth tracked separately via _is_authed + self._ready = True + + except Exception as ex: + log.error(f"setup_client failed: {ex}") + self.discord_client = None + self._ready = False + + finally: + self._connecting = False + + def update_client_credentials( + self, + client_id: str, + client_secret: str, + access_token: str = "", + refresh_token: str = "", + ): + if None in (client_id, client_secret) or "" in (client_id, client_secret): + self.frontend.on_auth_callback( + False, "actions.base.credentials.missing_client_info" + ) + return + + with self._setup_lock: + self.client_id = client_id + self.client_secret = client_secret + self.access_token = access_token + self.refresh_token = refresh_token + self._ready = False + self._is_authed = False + + if self.discord_client: + try: + self.discord_client.disconnect() + except Exception: + pass + self.discord_client = None + + # setup_client acquires its own lock + self.setup_client() def discord_callback(self, code, event): if code == 0: return + + if not event or not isinstance(event, str): + return + try: event = json.loads(event) except Exception as ex: - log.error(f"failed to parse discord event: {ex}") + log.error(f"failed to parse Discord event: {ex}") return + + # Handle Discord-side error codes (session invalidated) resp_code = ( event.get("data").get("code", 0) if event.get("data") is not None else 0 ) if resp_code in [4006, 4009]: if not self.refresh_token: + if self.discord_client: + self.discord_client.disconnect() self.setup_client() return try: @@ -41,6 +136,8 @@ def discord_callback(self, code, event): except Exception as ex: log.error(f"failed to refresh token {ex}") self._update_tokens("", "") + if self.discord_client: + self.discord_client.disconnect() self.setup_client() return access_token = token_resp.get("access_token") @@ -48,34 +145,47 @@ def discord_callback(self, code, event): self._update_tokens(access_token, refresh_token) self.discord_client.authenticate(self.access_token) return - match event.get("cmd"): - case commands.AUTHORIZE: - auth_code = event.get("data").get("code") - token_resp = self.discord_client.get_access_token(auth_code) - self.access_token = token_resp.get("access_token") - self.refresh_token = token_resp.get("refresh_token") - self.discord_client.authenticate(self.access_token) - self.frontend.save_access_token(self.access_token) - self.frontend.save_refresh_token(self.refresh_token) - case commands.AUTHENTICATE: - self.frontend.on_auth_callback(True) - self._is_authed = True - # Capture current user ID for filtering in UserVolume - data = event.get("data", {}) - user = data.get("user", {}) - self._register_callbacks() - self._current_user_id = user.get("id") - self._get_current_voice_channel() - case commands.DISPATCH: - evt = event.get("evt") - self.frontend.trigger_event(evt, event.get("data")) - case commands.GET_SELECTED_VOICE_CHANNEL: - self._current_voice_channel = ( - event.get("data").get("channel_id") if event.get("data") else None - ) - self.frontend.trigger_event(commands.VOICE_CHANNEL_SELECT, event.get("data")) - case commands.GET_CHANNEL: - self.frontend.trigger_event(commands.GET_CHANNEL, event.get("data")) + + cmd = event.get("cmd") + data = event.get("data") or {} + + if cmd == commands.AUTHORIZE: + token = self.discord_client.get_access_token(data.get("code")) + + self.access_token = token.get("access_token") + self.refresh_token = token.get("refresh_token") + + self.frontend.save_access_token(self.access_token) + self.frontend.save_refresh_token(self.refresh_token) + + self.discord_client.authenticate(self.access_token) + + elif cmd == commands.AUTHENTICATE: + self._is_authed = True + self.frontend.on_auth_callback(True) + + self._current_user_id = data.get("user", {}).get("id") + + self._register_callbacks() + self._get_current_voice_channel() + + elif cmd == commands.DISPATCH: + self.frontend.trigger_event(event.get("evt"), data) + + elif cmd == commands.GET_SELECTED_VOICE_CHANNEL: + self._current_voice_channel = ( + data.get("channel_id") if data else None + ) + self.frontend.trigger_event(commands.VOICE_CHANNEL_SELECT, data) + + elif cmd == commands.GET_CHANNEL: + self.frontend.trigger_event(commands.GET_CHANNEL, data) + + elif cmd == commands.VOICE_SETTINGS_UPDATE: + self.frontend.trigger_event( + commands.VOICE_SETTINGS_UPDATE, + data + ) def _update_tokens(self, access_token: str = "", refresh_token: str = ""): self.access_token = access_token @@ -83,120 +193,60 @@ def _update_tokens(self, access_token: str = "", refresh_token: str = ""): self.frontend.save_access_token(access_token) self.frontend.save_refresh_token(refresh_token) - def setup_client(self): - if self._is_reconnecting: - log.debug("Already reconnecting, skipping duplicate attempt") - return - try: - self._is_reconnecting = True - self.discord_client = AsyncDiscord(self.client_id, self.client_secret) - self.discord_client.connect(self.discord_callback) - if not self.access_token: - self.discord_client.authorize() - else: - self.discord_client.authenticate(self.access_token) - except Exception as ex: - self.frontend.on_auth_callback(False, str(ex)) - log.error("failed to setup discord client: {0}", ex) - if self.discord_client: - self.discord_client.disconnect() - self.discord_client = None - finally: - self._is_reconnecting = False + def _ensure_connected(self) -> bool: + """Ensure client is connected, attempt reconnection if needed.""" + self._ensure_ready() - def update_client_credentials( - self, - client_id: str, - client_secret: str, - access_token: str = "", - refresh_token: str = "", - ): - if None in (client_id, client_secret) or "" in (client_id, client_secret): - self.frontend.on_auth_callback( - False, "actions.base.credentials.missing_client_info" - ) - return - self.client_id = client_id - self.client_secret = client_secret - self.access_token = access_token - self.refresh_token = refresh_token - self.setup_client() + if self.discord_client and self.discord_client.is_connected(): + return True - def is_authed(self) -> bool: - return self._is_authed + # If we have credentials, attempt reconnection (setup_client handles locking) + if self.client_id and self.client_secret: + log.debug("Discord disconnected, attempting reconnect") + self.setup_client() + return self.discord_client and self.discord_client.is_connected() - def _register_callbacks(self): - self.discord_client.subscribe(commands.VOICE_SETTINGS_UPDATE) - self.discord_client.subscribe(commands.VOICE_CHANNEL_SELECT) - self.discord_client.subscribe(commands.GET_CHANNEL) - - def _ensure_connected(self) -> bool: - """Ensure client is connected, trigger reconnection if needed.""" - if self.discord_client is None or not self.discord_client.is_connected(): - if not self._is_reconnecting: - self.setup_client() - return False - return True + return False def set_mute(self, muted: bool): if not self._ensure_connected(): - log.warning("Discord client not connected, cannot set mute") + log.warning("Cannot set mute: Discord not connected") return - self.discord_client.set_voice_settings({"mute": muted}) - - def set_deafen(self, muted: bool): - if not self._ensure_connected(): - log.warning("Discord client not connected, cannot set deafen") + if not self._is_authed: + log.warning("Cannot set mute: Discord not authenticated") return - self.discord_client.set_voice_settings({"deaf": muted}) + self.discord_client.set_voice_settings({"mute": bool(muted)}) - def change_voice_channel(self, channel_id: str = None) -> bool: + def set_deafen(self, deafened: bool): if not self._ensure_connected(): - log.warning("Discord client not connected, cannot change voice channel") - return False - self.discord_client.select_voice_channel(channel_id, True) - return True - - def change_text_channel(self, channel_id: str) -> bool: - if not self._ensure_connected(): - log.warning("Discord client not connected, cannot change text channel") - return False - self.discord_client.select_text_channel(channel_id) - return True - - def set_push_to_talk(self, ptt: str) -> bool: - if not self._ensure_connected(): - log.warning("Discord client not connected, cannot set push to talk") - return False - self.discord_client.set_voice_settings({"mode": {"type": ptt}}) - return True + log.warning("Cannot set deafen: Discord not connected") + return + if not self._is_authed: + log.warning("Cannot set deafen: Discord not authenticated") + return + self.discord_client.set_voice_settings({"deaf": bool(deafened)}) - @property - def current_voice_channel(self): - return self._current_voice_channel + def is_authed(self) -> bool: + return self._is_authed - @property - def current_user_id(self): - return self._current_user_id + def _register_callbacks(self): + self.discord_client.subscribe(commands.VOICE_SETTINGS_UPDATE) + self.discord_client.subscribe(commands.VOICE_CHANNEL_SELECT) + self.discord_client.subscribe(commands.GET_CHANNEL) def _get_current_voice_channel(self): - if not self._ensure_connected(): - log.warning( - "Discord client not connected, cannot get current voice channel" - ) - return - self.discord_client.get_selected_voice_channel() - - def request_current_voice_channel(self): - """Public method to request current voice channel state (dispatches to callbacks).""" - self._get_current_voice_channel() + if self.discord_client: + self.discord_client.get_selected_voice_channel() # User volume control methods def set_user_volume(self, user_id: str, volume: int) -> bool: """Set volume for a specific user (0-200, 100 = normal).""" if not self._ensure_connected(): - log.warning("Discord client not connected, cannot set user volume") + log.warning("Cannot set user volume: Discord not connected") + return False + if not self._is_authed: + log.warning("Cannot set user volume: Discord not authenticated") return False self.discord_client.set_user_voice_settings(user_id, volume=volume) if user_id in self._voice_channel_users: @@ -206,7 +256,10 @@ def set_user_volume(self, user_id: str, volume: int) -> bool: def set_user_mute(self, user_id: str, muted: bool) -> bool: """Mute/unmute a specific user locally.""" if not self._ensure_connected(): - log.warning("Discord client not connected, cannot set user mute") + log.warning("Cannot set user mute: Discord not connected") + return False + if not self._is_authed: + log.warning("Cannot set user mute: Discord not authenticated") return False self.discord_client.set_user_voice_settings(user_id, mute=muted) if user_id in self._voice_channel_users: @@ -238,7 +291,10 @@ def get_voice_channel_users(self) -> dict: def get_channel(self, channel_id: str) -> bool: """Fetch channel information including voice states.""" if not self._ensure_connected(): - log.warning("Discord client not connected, cannot get channel") + log.warning("Cannot get channel: Discord not connected") + return False + if not self._is_authed: + log.warning("Cannot get channel: Discord not authenticated") return False self.discord_client.get_channel(channel_id) return True @@ -246,7 +302,10 @@ def get_channel(self, channel_id: str) -> bool: def subscribe_voice_states(self, channel_id: str) -> bool: """Subscribe to voice state events for a specific channel.""" if not self._ensure_connected(): - log.warning("Discord client not connected, cannot subscribe to voice states") + log.warning("Cannot subscribe to voice states: Discord not connected") + return False + if not self._is_authed: + log.warning("Cannot subscribe to voice states: Discord not authenticated") return False args = {"channel_id": channel_id} self.discord_client.subscribe(commands.VOICE_STATE_CREATE, args) @@ -257,6 +316,10 @@ def subscribe_voice_states(self, channel_id: str) -> bool: def unsubscribe_voice_states(self, channel_id: str) -> bool: """Unsubscribe from voice state events for a specific channel.""" if not self._ensure_connected(): + log.warning("Cannot unsubscribe from voice states: Discord not connected") + return False + if not self._is_authed: + log.warning("Cannot unsubscribe from voice states: Discord not authenticated") return False args = {"channel_id": channel_id} self.discord_client.unsubscribe(commands.VOICE_STATE_CREATE, args) @@ -264,14 +327,58 @@ def unsubscribe_voice_states(self, channel_id: str) -> bool: self.discord_client.unsubscribe(commands.VOICE_STATE_UPDATE, args) return True + def change_voice_channel(self, channel_id: str = None) -> bool: + if not self._ensure_connected(): + log.warning("Cannot change voice channel: Discord not connected") + return False + if not self._is_authed: + log.warning("Cannot change voice channel: Discord not authenticated") + return False + self.discord_client.select_voice_channel(channel_id, True) + return True + + def change_text_channel(self, channel_id: str) -> bool: + if not self._ensure_connected(): + log.warning("Cannot change text channel: Discord not connected") + return False + if not self._is_authed: + log.warning("Cannot change text channel: Discord not authenticated") + return False + self.discord_client.select_text_channel(channel_id) + return True + + def set_push_to_talk(self, ptt: str) -> bool: + if not self._ensure_connected(): + log.warning("Cannot set push to talk: Discord not connected") + return False + if not self._is_authed: + log.warning("Cannot set push to talk: Discord not authenticated") + return False + self.discord_client.set_voice_settings({"mode": {"type": ptt}}) + return True + + def request_current_voice_channel(self): + """Public method to request current voice channel state (dispatches to callbacks).""" + self._get_current_voice_channel() + + @property + def current_voice_channel(self): + return self._current_voice_channel + + @property + def current_user_id(self): + return self._current_user_id + def close(self): if self.discord_client: try: self.discord_client.disconnect() - except Exception as ex: - log.error(f"Error disconnecting Discord client: {ex}") - self.discord_client = None + except Exception: + pass + + self.discord_client = None self._is_authed = False + self._ready = False backend = Backend() diff --git a/discordrpc/asyncdiscord.py b/discordrpc/asyncdiscord.py index 242a299..72c7a1c 100644 --- a/discordrpc/asyncdiscord.py +++ b/discordrpc/asyncdiscord.py @@ -26,6 +26,7 @@ def __init__(self, client_id: str, client_secret: str, access_token: str = ""): self.client_secret = client_secret self.access_token = access_token self.polling = False + self._polling_thread = None self._session = requests.Session() # Reuse HTTP connections def _send_rpc_command(self, command: str, args: dict = None): @@ -35,7 +36,7 @@ def _send_rpc_command(self, command: str, args: dict = None): self.rpc.send(payload, OP_FRAME) def is_connected(self): - return self.polling + return self.polling and self.rpc.socket is not None def connect(self, callback: callable): tries = 0 @@ -71,10 +72,13 @@ def connect(self, callback: callable): if data.get("cmd") != "DISPATCH" or data.get("evt") != "READY": raise RPCException self.polling = True - threading.Thread(target=self.poll_callback, args=[callback]).start() + self._polling_thread = threading.Thread(target=self.poll_callback, args=[callback]) + self._polling_thread.start() def disconnect(self): self.polling = False + if self._polling_thread and self._polling_thread.is_alive(): + self._polling_thread.join(timeout=3) self.rpc.disconnect() if self._session: self._session.close() @@ -88,10 +92,16 @@ def poll_callback(self, callback: callable): except Exception as ex: log.error(f"error receiving data from socket. {ex}") self.disconnect() + continue + if val[0] == SOCKET_BAD_BUFFER_SIZE: log.debug("bad buffer size when receiving data from socket") + continue + if val[0] == SOCKET_DISCONNECTED: self.disconnect() + break + callback(val[0], val[1]) def authorize(self): diff --git a/discordrpc/constants.py b/discordrpc/constants.py index b5e1a1b..08cf244 100644 --- a/discordrpc/constants.py +++ b/discordrpc/constants.py @@ -5,6 +5,4 @@ MAX_IPC_SOCKET_RANGE = ( 10 # Number of IPC sockets to try (discord-ipc-0 through discord-ipc-9) ) -SOCKET_SELECT_TIMEOUT = 0.1 # Socket select timeout in seconds (reduced from 1.0s for 90% latency improvement) -#SOCKET_BUFFER_SIZE = 1024 # Socket receive buffer size in bytes -SOCKET_BUFFER_SIZE = 8 +SOCKET_RECEIVE_TIMEOUT = 0.5 diff --git a/discordrpc/sockets.py b/discordrpc/sockets.py index 198af84..9867db8 100644 --- a/discordrpc/sockets.py +++ b/discordrpc/sockets.py @@ -3,18 +3,17 @@ import struct import json import re -import select from loguru import logger as log from .exceptions import DiscordNotOpened -from .constants import MAX_IPC_SOCKET_RANGE, SOCKET_SELECT_TIMEOUT, SOCKET_BUFFER_SIZE +from .constants import MAX_IPC_SOCKET_RANGE, SOCKET_RECEIVE_TIMEOUT SOCKET_DISCONNECTED: int = -1 SOCKET_BAD_BUFFER_SIZE: int = -2 SOCKET_SEND_TIMEOUT: int = 5 SOCKET_CONNECT_TIMEOUT: int = 2 -SOCKET_RECEIVE_TIMEOUT: int = 10 + class UnixPipe: def __init__(self): @@ -22,70 +21,111 @@ def __init__(self): def connect(self): if self.socket is not None: - log.debug("Socket already connected, disconnecting first.") self.disconnect() + self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) self.socket.settimeout(SOCKET_CONNECT_TIMEOUT) - base_path = path = ( + + base_path = ( os.environ.get("XDG_RUNTIME_DIR") or os.environ.get("TMPDIR") or os.environ.get("TMP") or os.environ.get("TEMP") or "/tmp" ) - base_path = re.sub(r"\/$", "", path) + "/discord-ipc-{0}" + + base_path = re.sub(r"\/$", "", base_path) + "/discord-ipc-{0}" + for i in range(MAX_IPC_SOCKET_RANGE): path = base_path.format(i) try: - log.debug(f"Attempting to connect to socket at path: {path}") + log.debug(f"Trying Discord IPC socket: {path}") self.socket.connect(path) break except FileNotFoundError: - log.warning(f"socket {path} not found, trying next socket.") - pass + continue except Exception as ex: - log.error( - f"failed to connect to socket {path}, trying next socket. {ex}" - ) - # Skip all errors to try all sockets - pass + log.debug(f"Socket connect failed {path}: {ex}") else: raise DiscordNotOpened - log.debug(f"Connected to socket at path: {path}") - self.socket.setblocking(False) + + log.debug("Connected to Discord IPC socket") + self.socket.setblocking(True) + self.socket.settimeout(SOCKET_RECEIVE_TIMEOUT) def disconnect(self): if self.socket is None: return + try: self.socket.shutdown(socket.SHUT_RDWR) - except OSError as ex: - # Socket might already be disconnected - log.debug(f"Socket shutdown error (already disconnected): {ex}") + except Exception: + pass + try: self.socket.close() - except OSError as ex: - log.debug(f"Socket close error: {ex}") - self.socket = None # Reset so connect() creates a fresh socket + except Exception: + pass + + self.socket = None def send(self, payload, op): - payload_bytes = json.dumps(payload).encode("UTF-8") + payload_bytes = json.dumps(payload).encode("utf-8") header = struct.pack(" (int, str): - data = self.socket.recv(SOCKET_BUFFER_SIZE) - if len(data) == 0: - return SOCKET_DISCONNECTED, {} - header = data[:8] - code = int.from_bytes(header[:4], "little") - length = int.from_bytes(header[4:], "little") - all_data = b"" - if length < 0: - return SOCKET_BAD_BUFFER_SIZE, {} - if length > 0: - data = self.socket.recv(length) - all_data += data - return code, all_data.decode("UTF-8") + + orig_timeout = self.socket.gettimeout() + try: + self.socket.settimeout(SOCKET_SEND_TIMEOUT) + self.socket.sendall(message) + finally: + self.socket.settimeout(orig_timeout) + + def receive(self) -> tuple[int | None, str | None]: + try: + header = self._recv_exact(8) + + code = int.from_bytes(header[:4], "little") + length = int.from_bytes(header[4:], "little") + + if length < 0: + return SOCKET_BAD_BUFFER_SIZE, "" + + payload = self._recv_exact(length) + + return code, payload.decode("utf-8") + + except EOFError: + log.debug("Discord IPC connection closed or idle timeout") + return SOCKET_DISCONNECTED, "" + + except (OSError, socket.error) as ex: + log.error(f"Fatal socket error: {ex}") + return SOCKET_DISCONNECTED, "" + + except Exception as ex: + log.error(f"Unexpected receive error: {ex}") + return SOCKET_DISCONNECTED, "" + + def _recv_exact(self, size: int): + """Read exactly size bytes or raise EOFError.""" + data = b"" + + while len(data) < size: + try: + chunk = self.socket.recv(size - len(data)) + + # Peer closed the connection cleanly + if not chunk: + raise EOFError("Discord IPC socket closed") + + data += chunk + + except socket.timeout: + raise EOFError("Timeout during socket read - stream idle or corrupted") + + except Exception as ex: + log.debug(f"_recv_exact socket error: {ex}") + raise EOFError(f"Socket error during read: {ex}") + + return data diff --git a/main.py b/main.py index 7c3d5aa..2ddb9bb 100644 --- a/main.py +++ b/main.py @@ -242,6 +242,10 @@ def clear_callbacks(self, key: str, callback: callable): else: del self.callbacks[key] + def remove_callback(self, key: str, callback: callable): + """Remove a specific callback. Delegates to clear_callbacks.""" + self.clear_callbacks(key, callback) + def trigger_event(self, event_id_suffix: str, data: any): event_id = f"{self.get_plugin_id()}::{event_id_suffix}" if not event_id in self.event_holders: