diff --git a/src/codegen/agents/code_agent.py b/src/codegen/agents/code_agent.py index 99406ef40..73ce704d1 100644 --- a/src/codegen/agents/code_agent.py +++ b/src/codegen/agents/code_agent.py @@ -46,6 +46,7 @@ def __init__( agent_config: Optional[AgentConfig] = None, thread_id: Optional[str] = None, logger: Optional[ExternalLogger] = None, + multimodal: bool = True, **kwargs, ): """Initialize a CodeAgent. @@ -58,6 +59,10 @@ def __init__( tools: Additional tools to use tags: Tags to add to the agent trace. Must be of the same type. metadata: Metadata to use for the agent. Must be a dictionary. + agent_config: Configuration for the agent + thread_id: Optional thread ID for message history + logger: Optional external logger + multimodal: Whether to use a multimodal model (default: True) **kwargs: Additional LLM configuration options. Supported options: - temperature: Temperature parameter (0-1) - top_p: Top-p sampling parameter (0-1) @@ -65,6 +70,13 @@ def __init__( - max_tokens: Maximum number of tokens to generate """ self.codebase = codebase + + # If multimodal is enabled, ensure we're using a multimodal model + if multimodal and model_provider == "anthropic" and "claude-3" not in model_name: + # Default to Claude 3 Sonnet if multimodal is requested but model isn't Claude 3 + model_name = "claude-3-sonnet-20240229" + print(f"Multimodal support requested, using {model_name}") + self.agent = create_codebase_agent( self.codebase, model_provider=model_provider, diff --git a/src/codegen/extensions/langchain/llm.py b/src/codegen/extensions/langchain/llm.py index 0d4795740..57bd6a711 100644 --- a/src/codegen/extensions/langchain/llm.py +++ b/src/codegen/extensions/langchain/llm.py @@ -1,6 +1,7 @@ """LLM implementation supporting both OpenAI and Anthropic models.""" import os +import re from collections.abc import Sequence from typing import Any, Optional @@ -8,7 +9,7 @@ from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models.base import LanguageModelInput from langchain_core.language_models.chat_models import BaseChatModel -from langchain_core.messages import BaseMessage +from langchain_core.messages import BaseMessage, HumanMessage from langchain_core.outputs import ChatResult from langchain_core.runnables import Runnable from langchain_core.tools import BaseTool @@ -106,6 +107,65 @@ def _get_model(self) -> BaseChatModel: msg = f"Unknown model provider: {self.model_provider}. Must be one of: anthropic, openai, xai" raise ValueError(msg) + def _process_messages_for_multimodal(self, messages: list[BaseMessage]) -> list[BaseMessage]: + """Process messages to handle multimodal content (images). + + This function looks for image URLs in the format [Image: filename](URL) in message content + and converts them to the appropriate format for multimodal models. + + Args: + messages: List of messages to process + + Returns: + Processed messages with multimodal content + """ + processed_messages = [] + + for message in messages: + if not isinstance(message, HumanMessage): + # Only process human messages for now + processed_messages.append(message) + continue + + content = message.content + if isinstance(content, str): + # Check for image URLs in the format [Image: filename](URL) + image_pattern = r"\[Image(?:\s+\d+)?:\s+([^\]]+)\]\(([^)]+)\)" + matches = re.findall(image_pattern, content) + + if not matches: + # No images found, keep the message as is + processed_messages.append(message) + continue + + # Convert to multimodal format + multimodal_content = [] + last_end = 0 + + for match in re.finditer(image_pattern, content): + # Add text before the image + if match.start() > last_end: + multimodal_content.append({"type": "text", "text": content[last_end : match.start()]}) + + # Add the image + image_url = match.group(2) + multimodal_content.append({"type": "image_url", "image_url": {"url": image_url}}) + + last_end = match.end() + + # Add any remaining text after the last image + if last_end < len(content): + multimodal_content.append({"type": "text", "text": content[last_end:]}) + + # Create a new message with multimodal content + new_message = HumanMessage(content=multimodal_content) + processed_messages.append(new_message) + else: + # Content is already in a different format, keep as is + processed_messages.append(message) + + return processed_messages + def _generate( self, messages: list[BaseMessage], @@ -124,6 +184,11 @@ def _generate( Returns: ChatResult containing the generated completion """ + # Process messages for multimodal content if using a multimodal model + if self.model_provider == "anthropic" and "claude-3" in self.model_name: + processed_messages = self._process_messages_for_multimodal(messages) + return self._model._generate(processed_messages, stop=stop, run_manager=run_manager, **kwargs) + return self._model._generate(messages, stop=stop, run_manager=run_manager, **kwargs) def bind_tools(