Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 63 additions & 26 deletions src/codegen/agents/code_agent.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import os
import random
import time
from typing import TYPE_CHECKING, Optional
from uuid import uuid4

import anthropic
import openai
from langchain.tools import BaseTool
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.runnables.config import RunnableConfig
Expand All @@ -25,6 +29,9 @@ class CodeAgent:
project_name: str
thread_id: str | None = None
config: dict = {}
run_id: str | None = None
instance_id: str | None = None
difficulty: str | None = None

def __init__(
self,
Expand Down Expand Up @@ -105,32 +112,62 @@ def run(self, prompt: str, thread_id: Optional[str] = None) -> str:
input = {"query": prompt}

config = RunnableConfig(configurable={"thread_id": thread_id}, tags=self.tags, metadata=self.metadata, recursion_limit=100)
# we stream the steps instead of invoke because it allows us to access intermediate nodes
stream = self.agent.stream(input, config=config, stream_mode="values")

# Keep track of run IDs from the stream
run_ids = []

for s in stream:
if len(s["messages"]) == 0:
message = HumanMessage(content=prompt)
else:
message = s["messages"][-1]

if isinstance(message, tuple):
print(message)
else:
if isinstance(message, AIMessage) and isinstance(message.content, list) and len(message.content) > 0 and "text" in message.content[0]:
AIMessage(message.content[0]["text"]).pretty_print()
else:
message.pretty_print()

# Try to extract run ID if available in metadata
if hasattr(message, "additional_kwargs") and "run_id" in message.additional_kwargs:
run_ids.append(message.additional_kwargs["run_id"])

# Get the last message content
result = s["final_answer"]

# Implement retry mechanism for RateLimitError
max_retries = 10
initial_retry_delay = 30 # seconds
max_retry_delay = 1000 # seconds
retry_count = 0

while True:
try:
# we stream the steps instead of invoke because it allows us to access intermediate nodes
stream = self.agent.stream(input, config=config, stream_mode="values")

# Keep track of run IDs from the stream
run_ids = []

for s in stream:
if len(s["messages"]) == 0:
message = HumanMessage(content=prompt)
else:
message = s["messages"][-1]

if isinstance(message, tuple):
print(message)
else:
if isinstance(message, AIMessage) and isinstance(message.content, list) and len(message.content) > 0 and "text" in message.content[0]:
AIMessage(message.content[0]["text"]).pretty_print()
else:
message.pretty_print()

# Try to extract run ID if available in metadata
if hasattr(message, "additional_kwargs") and "run_id" in message.additional_kwargs:
run_ids.append(message.additional_kwargs["run_id"])

# Get the last message content
result = s["final_answer"]

# Successfully completed, break out of retry loop
break

except (anthropic.RateLimitError, openai.RateLimitError) as e:
retry_count += 1
if retry_count > max_retries:
msg = f"Maximum retry attempts ({max_retries}) exceeded for RateLimitError: {e!s}"
raise Exception(msg)

# Calculate backoff with exponential increase and some jitter
retry_delay = min(initial_retry_delay * (2 ** (retry_count - 1)), max_retry_delay)
jitter = retry_delay * 0.1 * (2 * (0.5 - random.random()))
retry_delay = retry_delay + jitter

print(f"Rate limit exceeded. Retrying in {retry_delay:.1f} seconds... (Attempt {retry_count}/{max_retries})")
time.sleep(retry_delay)
continue
except Exception as e:
# Re-raise other exceptions
raise e

# Try to find run IDs in the LangSmith client's recent runs
try:
Expand Down
Loading