diff --git a/agent-sdk-client/consumer.py b/agent-sdk-client/consumer.py index 5b7e6c8..34c57b8 100644 --- a/agent-sdk-client/consumer.py +++ b/agent-sdk-client/consumer.py @@ -15,6 +15,23 @@ from config import Config +def _get_reply_to_id(message_id: int, thread_id: int | None, message_thread_id: int | None) -> int | None: + """Determine if we should reply to the original message. + + Only reply to the original message if we're in the same thread. + This prevents Telegram API errors when sending to a different thread (e.g., /newchat). + + Args: + message_id: The original message ID + thread_id: The target thread ID (may be overridden by handler) + message_thread_id: The original message's thread ID + + Returns: + message_id if in same thread, None otherwise + """ + return message_id if thread_id == message_thread_id else None + + def lambda_handler(event: dict, context: Any) -> dict: """SQS Consumer Lambda entry point.""" for record in event['Records']: @@ -43,6 +60,8 @@ async def process_message(message_data: dict) -> None: """Process single message from SQS queue.""" import logging logger = logging.getLogger() + # Enable INFO logging as suggested in issue for better debugging + logger.setLevel(logging.INFO) config = Config.from_env() bot = Bot(config.telegram_token) @@ -55,6 +74,11 @@ async def process_message(message_data: dict) -> None: logger.warning("Received update with no message or edited_message") return + # Extract thread_id and user_message early - needed for all message processing + # (allows handler to override text/thread_id via SQS message_data) + user_message = message_data.get('text') or message.text + thread_id = message_data.get('thread_id') or message.message_thread_id + cmd = config.get_command(message.text) if cmd: if config.is_local_command(cmd): @@ -62,12 +86,13 @@ async def process_message(message_data: dict) -> None: "Handling local command in consumer (fallback path)", extra={'chat_id': message.chat_id, 'message_id': message.message_id}, ) + reply_to_id = _get_reply_to_id(message.message_id, thread_id, message.message_thread_id) try: await bot.send_message( chat_id=message.chat_id, text=config.local_response(cmd), - message_thread_id=message.message_thread_id, - reply_to_message_id=message.message_id, + message_thread_id=thread_id, + reply_to_message_id=reply_to_id, ) except Exception: logger.warning("Failed to send local command response", exc_info=True) @@ -82,12 +107,13 @@ async def process_message(message_data: dict) -> None: 'message_id': message.message_id, }, ) + reply_to_id = _get_reply_to_id(message.message_id, thread_id, message.message_thread_id) try: await bot.send_message( chat_id=message.chat_id, text=config.unknown_command_message(), - message_thread_id=message.message_thread_id, - reply_to_message_id=message.message_id, + message_thread_id=thread_id, + reply_to_message_id=reply_to_id, ) except Exception: logger.warning("Failed to send local command response", exc_info=True) @@ -97,7 +123,7 @@ async def process_message(message_data: dict) -> None: await bot.send_chat_action( chat_id=message.chat_id, action=ChatAction.TYPING, - message_thread_id=message.message_thread_id, + message_thread_id=thread_id, ) # Initialize result with default error response @@ -108,10 +134,6 @@ async def process_message(message_data: dict) -> None: 'error_message': 'Failed to get response from Agent Server' } - # Use message_data fields for SQS message (allows handler to override text/thread_id) - user_message = message_data.get('text') or message.text - thread_id = message_data.get('thread_id') or message.message_thread_id - # Call Agent Server try: async with httpx.AsyncClient(timeout=600.0) as client: @@ -132,21 +154,25 @@ async def process_message(message_data: dict) -> None: except httpx.TimeoutException: logger.warning(f"Agent Server timeout for chat_id={message.chat_id}") + reply_to_id = _get_reply_to_id(message.message_id, thread_id, message.message_thread_id) await bot.send_message( chat_id=message.chat_id, text="Request timed out.", message_thread_id=thread_id, + reply_to_message_id=reply_to_id, ) raise # Re-raise to trigger SQS retry for transient errors except Exception as e: logger.exception(f"Agent Server error for chat_id={message.chat_id}") error_text = f"Error: {str(e)[:200]}" + reply_to_id = _get_reply_to_id(message.message_id, thread_id, message.message_thread_id) try: await bot.send_message( chat_id=message.chat_id, text=error_text, message_thread_id=thread_id, + reply_to_message_id=reply_to_id, ) except Exception as send_error: logger.error(f"Failed to send error message to Telegram: {send_error}") @@ -162,23 +188,35 @@ async def process_message(message_data: dict) -> None: text = text[:4000] + "\n\n... (truncated)" # Send response to Telegram + reply_to_id = _get_reply_to_id(message.message_id, thread_id, message.message_thread_id) + try: await bot.send_message( chat_id=message.chat_id, text=text, parse_mode=ParseMode.MARKDOWN_V2, message_thread_id=thread_id, - reply_to_message_id=message.message_id, + reply_to_message_id=reply_to_id, + ) + logger.info( + f"Message sent successfully to chat_id={message.chat_id}, " + f"thread_id={thread_id}, reply_to={reply_to_id}" ) except BadRequest as e: if "parse entities" in str(e).lower(): + logger.warning(f"Markdown parse error, retrying with escaped text: {e}") safe_text = escape_markdown(text, version=2) await bot.send_message( chat_id=message.chat_id, text=safe_text, parse_mode=ParseMode.MARKDOWN_V2, message_thread_id=thread_id, - reply_to_message_id=message.message_id, + reply_to_message_id=reply_to_id, + ) + logger.info( + f"Message sent successfully (escaped) to chat_id={message.chat_id}, " + f"thread_id={thread_id}, reply_to={reply_to_id}" ) else: + logger.error(f"Failed to send message: {e}") raise