diff --git a/src/codegen/agents/code_agent.py b/src/codegen/agents/code_agent.py index e27f485da..f613b2fc4 100644 --- a/src/codegen/agents/code_agent.py +++ b/src/codegen/agents/code_agent.py @@ -1,7 +1,11 @@ import os +import random +import time from typing import TYPE_CHECKING, Optional from uuid import uuid4 +import anthropic +import openai from langchain.tools import BaseTool from langchain_core.messages import AIMessage, HumanMessage from langchain_core.runnables.config import RunnableConfig @@ -25,6 +29,9 @@ class CodeAgent: project_name: str thread_id: str | None = None config: dict = {} + run_id: str | None = None + instance_id: str | None = None + difficulty: str | None = None def __init__( self, @@ -105,32 +112,62 @@ def run(self, prompt: str, thread_id: Optional[str] = None) -> str: input = {"query": prompt} config = RunnableConfig(configurable={"thread_id": thread_id}, tags=self.tags, metadata=self.metadata, recursion_limit=100) - # we stream the steps instead of invoke because it allows us to access intermediate nodes - stream = self.agent.stream(input, config=config, stream_mode="values") - - # Keep track of run IDs from the stream - run_ids = [] - - for s in stream: - if len(s["messages"]) == 0: - message = HumanMessage(content=prompt) - else: - message = s["messages"][-1] - - if isinstance(message, tuple): - print(message) - else: - if isinstance(message, AIMessage) and isinstance(message.content, list) and len(message.content) > 0 and "text" in message.content[0]: - AIMessage(message.content[0]["text"]).pretty_print() - else: - message.pretty_print() - - # Try to extract run ID if available in metadata - if hasattr(message, "additional_kwargs") and "run_id" in message.additional_kwargs: - run_ids.append(message.additional_kwargs["run_id"]) - - # Get the last message content - result = s["final_answer"] + + # Implement retry mechanism for RateLimitError + max_retries = 10 + initial_retry_delay = 30 # seconds + max_retry_delay = 1000 # seconds + retry_count = 0 + + while True: + try: + # we stream the steps instead of invoke because it allows us to access intermediate nodes + stream = self.agent.stream(input, config=config, stream_mode="values") + + # Keep track of run IDs from the stream + run_ids = [] + + for s in stream: + if len(s["messages"]) == 0: + message = HumanMessage(content=prompt) + else: + message = s["messages"][-1] + + if isinstance(message, tuple): + print(message) + else: + if isinstance(message, AIMessage) and isinstance(message.content, list) and len(message.content) > 0 and "text" in message.content[0]: + AIMessage(message.content[0]["text"]).pretty_print() + else: + message.pretty_print() + + # Try to extract run ID if available in metadata + if hasattr(message, "additional_kwargs") and "run_id" in message.additional_kwargs: + run_ids.append(message.additional_kwargs["run_id"]) + + # Get the last message content + result = s["final_answer"] + + # Successfully completed, break out of retry loop + break + + except (anthropic.RateLimitError, openai.RateLimitError) as e: + retry_count += 1 + if retry_count > max_retries: + msg = f"Maximum retry attempts ({max_retries}) exceeded for RateLimitError: {e!s}" + raise Exception(msg) + + # Calculate backoff with exponential increase and some jitter + retry_delay = min(initial_retry_delay * (2 ** (retry_count - 1)), max_retry_delay) + jitter = retry_delay * 0.1 * (2 * (0.5 - random.random())) + retry_delay = retry_delay + jitter + + print(f"Rate limit exceeded. Retrying in {retry_delay:.1f} seconds... (Attempt {retry_count}/{max_retries})") + time.sleep(retry_delay) + continue + except Exception as e: + # Re-raise other exceptions + raise e # Try to find run IDs in the LangSmith client's recent runs try: