Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 49 additions & 11 deletions agent-sdk-client/consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,23 @@
from config import Config


def _get_reply_to_id(message_id: int, thread_id: int | None, message_thread_id: int | None) -> int | None:
"""Determine if we should reply to the original message.

Only reply to the original message if we're in the same thread.
This prevents Telegram API errors when sending to a different thread (e.g., /newchat).

Args:
message_id: The original message ID
thread_id: The target thread ID (may be overridden by handler)
message_thread_id: The original message's thread ID

Returns:
message_id if in same thread, None otherwise
"""
return message_id if thread_id == message_thread_id else None


def lambda_handler(event: dict, context: Any) -> dict:
"""SQS Consumer Lambda entry point."""
for record in event['Records']:
Expand Down Expand Up @@ -43,6 +60,8 @@ async def process_message(message_data: dict) -> None:
"""Process single message from SQS queue."""
import logging
logger = logging.getLogger()
# Enable INFO logging as suggested in issue for better debugging
logger.setLevel(logging.INFO)

config = Config.from_env()
bot = Bot(config.telegram_token)
Expand All @@ -55,19 +74,25 @@ async def process_message(message_data: dict) -> None:
logger.warning("Received update with no message or edited_message")
return

# Extract thread_id and user_message early - needed for all message processing
# (allows handler to override text/thread_id via SQS message_data)
user_message = message_data.get('text') or message.text
thread_id = message_data.get('thread_id') or message.message_thread_id

cmd = config.get_command(message.text)
if cmd:
if config.is_local_command(cmd):
logger.info(
"Handling local command in consumer (fallback path)",
extra={'chat_id': message.chat_id, 'message_id': message.message_id},
)
reply_to_id = _get_reply_to_id(message.message_id, thread_id, message.message_thread_id)
try:
await bot.send_message(
chat_id=message.chat_id,
text=config.local_response(cmd),
message_thread_id=message.message_thread_id,
reply_to_message_id=message.message_id,
message_thread_id=thread_id,
reply_to_message_id=reply_to_id,
)
except Exception:
logger.warning("Failed to send local command response", exc_info=True)
Expand All @@ -82,12 +107,13 @@ async def process_message(message_data: dict) -> None:
'message_id': message.message_id,
},
)
reply_to_id = _get_reply_to_id(message.message_id, thread_id, message.message_thread_id)
try:
await bot.send_message(
chat_id=message.chat_id,
text=config.unknown_command_message(),
message_thread_id=message.message_thread_id,
reply_to_message_id=message.message_id,
message_thread_id=thread_id,
reply_to_message_id=reply_to_id,
)
except Exception:
logger.warning("Failed to send local command response", exc_info=True)
Expand All @@ -97,7 +123,7 @@ async def process_message(message_data: dict) -> None:
await bot.send_chat_action(
chat_id=message.chat_id,
action=ChatAction.TYPING,
message_thread_id=message.message_thread_id,
message_thread_id=thread_id,
)

# Initialize result with default error response
Expand All @@ -108,10 +134,6 @@ async def process_message(message_data: dict) -> None:
'error_message': 'Failed to get response from Agent Server'
}

# Use message_data fields for SQS message (allows handler to override text/thread_id)
user_message = message_data.get('text') or message.text
thread_id = message_data.get('thread_id') or message.message_thread_id

# Call Agent Server
try:
async with httpx.AsyncClient(timeout=600.0) as client:
Expand All @@ -132,21 +154,25 @@ async def process_message(message_data: dict) -> None:

except httpx.TimeoutException:
logger.warning(f"Agent Server timeout for chat_id={message.chat_id}")
reply_to_id = _get_reply_to_id(message.message_id, thread_id, message.message_thread_id)
await bot.send_message(
chat_id=message.chat_id,
text="Request timed out.",
message_thread_id=thread_id,
reply_to_message_id=reply_to_id,
)
raise # Re-raise to trigger SQS retry for transient errors

except Exception as e:
logger.exception(f"Agent Server error for chat_id={message.chat_id}")
error_text = f"Error: {str(e)[:200]}"
reply_to_id = _get_reply_to_id(message.message_id, thread_id, message.message_thread_id)
try:
await bot.send_message(
chat_id=message.chat_id,
text=error_text,
message_thread_id=thread_id,
reply_to_message_id=reply_to_id,
)
except Exception as send_error:
logger.error(f"Failed to send error message to Telegram: {send_error}")
Expand All @@ -162,23 +188,35 @@ async def process_message(message_data: dict) -> None:
text = text[:4000] + "\n\n... (truncated)"

# Send response to Telegram
reply_to_id = _get_reply_to_id(message.message_id, thread_id, message.message_thread_id)

try:
await bot.send_message(
chat_id=message.chat_id,
text=text,
parse_mode=ParseMode.MARKDOWN_V2,
message_thread_id=thread_id,
reply_to_message_id=message.message_id,
reply_to_message_id=reply_to_id,
)
logger.info(
f"Message sent successfully to chat_id={message.chat_id}, "
f"thread_id={thread_id}, reply_to={reply_to_id}"
)
except BadRequest as e:
if "parse entities" in str(e).lower():
logger.warning(f"Markdown parse error, retrying with escaped text: {e}")
safe_text = escape_markdown(text, version=2)
await bot.send_message(
chat_id=message.chat_id,
text=safe_text,
parse_mode=ParseMode.MARKDOWN_V2,
message_thread_id=thread_id,
reply_to_message_id=message.message_id,
reply_to_message_id=reply_to_id,
)
logger.info(
f"Message sent successfully (escaped) to chat_id={message.chat_id}, "
f"thread_id={thread_id}, reply_to={reply_to_id}"
)
else:
logger.error(f"Failed to send message: {e}")
raise