diff --git a/README.md b/README.md
index 740909f5e5..c4feb65c18 100644
--- a/README.md
+++ b/README.md
@@ -202,7 +202,7 @@ OpenEvolve implements a sophisticated **evolutionary coding pipeline** that goes
 Advanced LLM Integration
-- **Universal API**: Works with OpenAI, Google, local models, and proxies
+- **Universal API**: Works with OpenAI, Google, Claude Code CLI, local models, and proxies
 - **Intelligent Ensembles**: Weighted combinations with sophisticated fallback
 - **Test-Time Compute**: Enhanced reasoning through proxy systems (see [OptiLLM setup](#llm-provider-setup))
 - **Plugin Ecosystem**: Support for advanced reasoning plugins
@@ -233,7 +233,7 @@ OpenEvolve implements a sophisticated **evolutionary coding pipeline** that goes
 ### Requirements
 - **Python**: 3.10+
-- **LLM Access**: Any OpenAI-compatible API
+- **LLM Access**: Any OpenAI-compatible API, or [Claude Code CLI](https://docs.anthropic.com/en/docs/claude-code)
 - **Optional**: Docker for containerized runs
 ### Installation Options
@@ -356,6 +356,35 @@ llm:
+<details>
+<summary>🔮 Claude Code CLI (No API Key)</summary>
+
+Use the [Claude Code CLI](https://docs.anthropic.com/en/docs/claude-code) as the LLM backend — no API keys needed, authentication uses the CLI's OAuth session.
+
+```bash
+# Install and authenticate
+npm install -g @anthropic-ai/claude-code
+claude login
+```
+
+```yaml
+# config.yaml
+llm:
+  provider: "claude_code"
+  models:
+    - name: "sonnet"
+      weight: 0.8
+      max_tokens: 16000
+      max_budget_usd: 1.0
+    - name: "haiku"
+      weight: 0.2
+      max_tokens: 8000
+```
+
+See the [Claude Code quickstart example](examples/claude_code_quickstart/) for a complete walkthrough.
+
+</details>
+
 ## Examples Gallery
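As a supplement to the README section above, here is a minimal programmatic sketch of the new backend used outside the evolution loop. It is illustrative only and not part of this diff: it assumes the `claude` CLI is installed and logged in, and it relies only on the `ClaudeCodeLLM` class and the `name`/`max_budget_usd` config fields introduced by the patch below.

```python
# Minimal sketch (illustrative, not part of the diff): call the Claude Code CLI
# backend directly. Assumes `claude login` has already been run; every field not
# set here falls back to the defaults defined in ClaudeCodeLLM.__init__.
import asyncio

from openevolve.config import LLMModelConfig
from openevolve.llm.claude_code import ClaudeCodeLLM

cfg = LLMModelConfig(name="sonnet", max_budget_usd=0.5)
llm = ClaudeCodeLLM(cfg)

# generate() is async and shells out to `claude -p ... --output-format text`.
print(asyncio.run(llm.generate("Reply with the single word OK")))
```

This is the same object that `_create_model` in `openevolve/llm/ensemble.py` returns for models whose `provider` is set to `"claude_code"`.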
diff --git a/examples/claude_code_quickstart/README.md b/examples/claude_code_quickstart/README.md new file mode 100644 index 0000000000..585f475648 --- /dev/null +++ b/examples/claude_code_quickstart/README.md @@ -0,0 +1,72 @@ +# Claude Code CLI Quickstart + +This example shows how to use the [Claude Code CLI](https://docs.anthropic.com/en/docs/claude-code) as the LLM backend for OpenEvolve. No API keys are needed — authentication uses the CLI's existing OAuth session. + +## Prerequisites + +1. **Install Claude Code CLI:** + ```bash + npm install -g @anthropic-ai/claude-code + ``` + +2. **Authenticate:** + ```bash + claude login + ``` + +3. **Install OpenEvolve:** + ```bash + pip install openevolve + ``` + +## Run + +```bash +python openevolve-run.py \ + examples/claude_code_quickstart/initial_program.py \ + examples/claude_code_quickstart/evaluator.py \ + --config examples/claude_code_quickstart/config.yaml \ + --iterations 50 +``` + +## How It Works + +The `config.yaml` sets `provider: "claude_code"` which routes all LLM calls through the `claude -p` subprocess instead of the OpenAI-compatible API. The CLI handles authentication, model selection, and billing. + +### Key Config Options + +| Field | Description | Default | +|-------|-------------|---------| +| `provider` | Set to `"claude_code"` to use the CLI backend | `"openai"` | +| `name` | Claude model name (`sonnet`, `haiku`, `opus`) | `"sonnet"` | +| `max_budget_usd` | Per-call spending cap in USD | `1.0` | +| `timeout` | CLI timeout in seconds | `300` | +| `retries` | Number of retry attempts on failure | `3` | +| `retry_delay` | Seconds between retries | `5` | + +### Ensemble Example + +You can mix Claude models in an ensemble, just like with OpenAI models: + +```yaml +llm: + provider: "claude_code" + models: + - name: "sonnet" + weight: 0.8 + max_tokens: 16000 + - name: "haiku" + weight: 0.2 + max_tokens: 8000 +``` + +### Programmatic Usage + +You can also inject the Claude Code backend at runtime without modifying config files: + +```python +from openevolve.llm.claude_code import init_claude_code_client + +for model_cfg in config.llm.models: + model_cfg.init_client = init_claude_code_client +``` diff --git a/examples/claude_code_quickstart/config.yaml b/examples/claude_code_quickstart/config.yaml new file mode 100644 index 0000000000..d1c554e561 --- /dev/null +++ b/examples/claude_code_quickstart/config.yaml @@ -0,0 +1,57 @@ +# Configuration for function minimization using Claude Code CLI as the LLM backend. +# No API keys needed — authentication uses `claude login` (OAuth session). +# +# Prerequisites: +# 1. Install Claude Code CLI: npm install -g @anthropic-ai/claude-code +# 2. Authenticate: claude login +# +# Run: +# python openevolve-run.py \ +# examples/claude_code_quickstart/initial_program.py \ +# examples/claude_code_quickstart/evaluator.py \ +# --config examples/claude_code_quickstart/config.yaml \ +# --iterations 50 + +max_iterations: 50 +checkpoint_interval: 10 + +llm: + provider: "claude_code" + models: + - name: "sonnet" + weight: 0.8 + max_tokens: 16000 + timeout: 300 + max_budget_usd: 1.0 + - name: "haiku" + weight: 0.2 + max_tokens: 8000 + timeout: 120 + max_budget_usd: 0.5 + retries: 3 + retry_delay: 5 + +prompt: + system_message: > + You are an expert programmer specializing in optimization algorithms. + Your task is to improve a function minimization algorithm to find the + global minimum of a complex function with many local minima. + The function is f(x, y) = sin(x) * cos(y) + sin(x*y) + (x^2 + y^2)/20. 
+ Focus on improving the search_algorithm function to reliably find the + global minimum, escaping local minima that might trap simple algorithms. + +database: + population_size: 50 + archive_size: 20 + num_islands: 3 + elite_selection_ratio: 0.2 + exploitation_ratio: 0.7 + similarity_threshold: 0.99 + +evaluator: + timeout: 60 + cascade_thresholds: [1.3] + parallel_evaluations: 3 + +diff_based_evolution: true +max_code_length: 20000 diff --git a/examples/claude_code_quickstart/evaluator.py b/examples/claude_code_quickstart/evaluator.py new file mode 100644 index 0000000000..f16318125b --- /dev/null +++ b/examples/claude_code_quickstart/evaluator.py @@ -0,0 +1,493 @@ +""" +Evaluator for the function minimization example +""" + +import importlib.util +import numpy as np +import time +import concurrent.futures +import traceback +import signal +from openevolve.evaluation_result import EvaluationResult + + +def run_with_timeout(func, args=(), kwargs={}, timeout_seconds=5): + """ + Run a function with a timeout using concurrent.futures + + Args: + func: Function to run + args: Arguments to pass to the function + kwargs: Keyword arguments to pass to the function + timeout_seconds: Timeout in seconds + + Returns: + Result of the function or raises TimeoutError + """ + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(func, *args, **kwargs) + try: + result = future.result(timeout=timeout_seconds) + return result + except concurrent.futures.TimeoutError: + raise TimeoutError(f"Function timed out after {timeout_seconds} seconds") + + +def safe_float(value): + """Convert a value to float safely""" + try: + return float(value) + except (TypeError, ValueError): + print(f"Warning: Could not convert {value} of type {type(value)} to float") + return 0.0 + + +def evaluate(program_path): + """ + Evaluate the program by running it multiple times and checking how close + it gets to the known global minimum. 
+ + Args: + program_path: Path to the program file + + Returns: + Dictionary of metrics + """ + # Known global minimum (approximate) + GLOBAL_MIN_X = -1.704 + GLOBAL_MIN_Y = 0.678 + GLOBAL_MIN_VALUE = -1.519 + + try: + # Load the program + spec = importlib.util.spec_from_file_location("program", program_path) + program = importlib.util.module_from_spec(spec) + spec.loader.exec_module(program) + + # Check if the required function exists + if not hasattr(program, "run_search"): + print(f"Error: program does not have 'run_search' function") + + error_artifacts = { + "error_type": "MissingFunction", + "error_message": "Program is missing required 'run_search' function", + "suggestion": "Make sure your program includes a function named 'run_search' that returns (x, y, value) or (x, y)" + } + + return EvaluationResult( + metrics={ + "value_score": 0.0, + "distance_score": 0.0, + "reliability_score": 0.0, + "combined_score": 0.0, + "error": "Missing run_search function", + }, + artifacts=error_artifacts + ) + + # Run multiple trials + num_trials = 10 + x_values = [] + y_values = [] + values = [] + distances = [] + times = [] + success_count = 0 + + for trial in range(num_trials): + try: + start_time = time.time() + + # Run with timeout + result = run_with_timeout(program.run_search, timeout_seconds=5) + + # Handle different result formats + if isinstance(result, tuple): + if len(result) == 3: + x, y, value = result + elif len(result) == 2: + # Assume it's (x, y) and calculate value + x, y = result + # Calculate the function value since it wasn't returned + value = np.sin(x) * np.cos(y) + np.sin(x * y) + (x**2 + y**2) / 20 + print(f"Trial {trial}: Got 2 values, calculated function value: {value}") + else: + print( + f"Trial {trial}: Invalid result format, expected tuple of 2 or 3 values but got {len(result)}" + ) + continue + else: + print( + f"Trial {trial}: Invalid result format, expected tuple but got {type(result)}" + ) + continue + + end_time = time.time() + + # Ensure all values are float + x = safe_float(x) + y = safe_float(y) + value = safe_float(value) + + # Check if the result is valid (not NaN or infinite) + if ( + np.isnan(x) + or np.isnan(y) + or np.isnan(value) + or np.isinf(x) + or np.isinf(y) + or np.isinf(value) + ): + print(f"Trial {trial}: Invalid result, got x={x}, y={y}, value={value}") + continue + + # Calculate metrics + x_diff = x - GLOBAL_MIN_X + y_diff = y - GLOBAL_MIN_Y + distance_to_global = np.sqrt(x_diff**2 + y_diff**2) + + x_values.append(x) + y_values.append(y) + values.append(value) + distances.append(distance_to_global) + times.append(end_time - start_time) + success_count += 1 + + except TimeoutError as e: + print(f"Trial {trial}: {str(e)}") + continue + except IndexError as e: + # Specifically handle IndexError which often happens with early termination checks + print(f"Trial {trial}: IndexError - {str(e)}") + print( + "This is likely due to a list index check before the list is fully populated." 
+ ) + continue + except Exception as e: + print(f"Trial {trial}: Error - {str(e)}") + print(traceback.format_exc()) + continue + + # If all trials failed, return zero scores + if success_count == 0: + error_artifacts = { + "error_type": "AllTrialsFailed", + "error_message": f"All {num_trials} trials failed - common issues: timeouts, crashes, or invalid return values", + "suggestion": "Check for infinite loops, ensure function returns (x, y) or (x, y, value), and verify algorithm terminates within time limit" + } + + return EvaluationResult( + metrics={ + "value_score": 0.0, + "distance_score": 0.0, + "reliability_score": 0.0, + "combined_score": 0.0, + "error": "All trials failed", + }, + artifacts=error_artifacts + ) + + # Calculate metrics + avg_value = float(np.mean(values)) + avg_distance = float(np.mean(distances)) + avg_time = float(np.mean(times)) if times else 1.0 + + # Convert to scores (higher is better) + value_score = float(1.0 / (1.0 + abs(avg_value - GLOBAL_MIN_VALUE))) + distance_score = float(1.0 / (1.0 + avg_distance)) + + # Add reliability score based on success rate + reliability_score = float(success_count / num_trials) + + # Calculate solution quality based on distance to global minimum + if avg_distance < 0.5: # Very close to the correct solution + solution_quality_multiplier = 1.5 # 50% bonus + elif avg_distance < 1.5: # In the right region + solution_quality_multiplier = 1.2 # 20% bonus + elif avg_distance < 3.0: # Getting closer + solution_quality_multiplier = 1.0 # No adjustment + else: # Not finding the right region + solution_quality_multiplier = 0.7 # 30% penalty + + # Calculate combined score that prioritizes finding the global minimum + # Base score from value and distance, then apply solution quality multiplier + base_score = 0.5 * value_score + 0.3 * distance_score + 0.2 * reliability_score + combined_score = float(base_score * solution_quality_multiplier) + + # Add artifacts for successful runs + artifacts = { + "convergence_info": f"Converged in {num_trials} trials with {success_count} successes", + "best_position": f"Final position: x={x_values[-1]:.4f}, y={y_values[-1]:.4f}" if x_values else "No successful trials", + "average_distance_to_global": f"{avg_distance:.4f}", + "search_efficiency": f"Success rate: {reliability_score:.2%}" + } + + return EvaluationResult( + metrics={ + "value_score": value_score, + "distance_score": distance_score, + "reliability_score": reliability_score, + "combined_score": combined_score, + }, + artifacts=artifacts + ) + except Exception as e: + print(f"Evaluation failed completely: {str(e)}") + print(traceback.format_exc()) + + # Create error artifacts + error_artifacts = { + "error_type": type(e).__name__, + "error_message": str(e), + "full_traceback": traceback.format_exc(), + "suggestion": "Check for syntax errors or missing imports in the generated code" + } + + return EvaluationResult( + metrics={ + "value_score": 0.0, + "distance_score": 0.0, + "reliability_score": 0.0, + "combined_score": 0.0, + "error": str(e), + }, + artifacts=error_artifacts + ) + + +# Stage-based evaluation for cascade evaluation +def evaluate_stage1(program_path): + """First stage evaluation with fewer trials""" + # Known global minimum (approximate) + GLOBAL_MIN_X = float(-1.704) + GLOBAL_MIN_Y = float(0.678) + GLOBAL_MIN_VALUE = float(-1.519) + + # Quick check to see if the program runs without errors + try: + # Load the program + spec = importlib.util.spec_from_file_location("program", program_path) + program = 
importlib.util.module_from_spec(spec) + spec.loader.exec_module(program) + + # Check if the required function exists + if not hasattr(program, "run_search"): + print(f"Stage 1 validation: Program does not have 'run_search' function") + + error_artifacts = { + "error_type": "MissingFunction", + "error_message": "Stage 1: Program is missing required 'run_search' function", + "suggestion": "Make sure your program includes a function named 'run_search' that returns (x, y, value) or (x, y)" + } + + return EvaluationResult( + metrics={ + "runs_successfully": 0.0, + "combined_score": 0.0, + "error": "Missing run_search function" + }, + artifacts=error_artifacts + ) + + try: + # Run a single trial with timeout + result = run_with_timeout(program.run_search, timeout_seconds=5) + + # Handle different result formats + if isinstance(result, tuple): + if len(result) == 3: + x, y, value = result + elif len(result) == 2: + # Assume it's (x, y) and calculate value + x, y = result + # Calculate the function value since it wasn't returned + value = np.sin(x) * np.cos(y) + np.sin(x * y) + (x**2 + y**2) / 20 + print(f"Stage 1: Got 2 values, calculated function value: {value}") + else: + print( + f"Stage 1: Invalid result format, expected tuple of 2 or 3 values but got {len(result)}" + ) + + error_artifacts = { + "error_type": "InvalidReturnFormat", + "error_message": f"Stage 1: Function returned tuple with {len(result)} values, expected 2 or 3", + "suggestion": "run_search() must return (x, y) or (x, y, value) - check your return statement" + } + + return EvaluationResult( + metrics={ + "runs_successfully": 0.0, + "combined_score": 0.0, + "error": "Invalid result format" + }, + artifacts=error_artifacts + ) + else: + print(f"Stage 1: Invalid result format, expected tuple but got {type(result)}") + + error_artifacts = { + "error_type": "InvalidReturnType", + "error_message": f"Stage 1: Function returned {type(result)}, expected tuple", + "suggestion": "run_search() must return a tuple like (x, y) or (x, y, value), not a single value or other type" + } + + return EvaluationResult( + metrics={ + "runs_successfully": 0.0, + "combined_score": 0.0, + "error": "Invalid result format" + }, + artifacts=error_artifacts + ) + + # Ensure all values are float + x = safe_float(x) + y = safe_float(y) + value = safe_float(value) + + # Check if the result is valid + if ( + np.isnan(x) + or np.isnan(y) + or np.isnan(value) + or np.isinf(x) + or np.isinf(y) + or np.isinf(value) + ): + print(f"Stage 1 validation: Invalid result, got x={x}, y={y}, value={value}") + + error_artifacts = { + "error_type": "InvalidResultValues", + "error_message": f"Stage 1: Got invalid values - x={x}, y={y}, value={value}", + "suggestion": "Function returned NaN or infinite values. 
Check for division by zero, invalid math operations, or uninitialized variables" + } + + return EvaluationResult( + metrics={ + "runs_successfully": 0.5, + "combined_score": 0.0, + "error": "Invalid result values" + }, + artifacts=error_artifacts + ) + + # Calculate distance safely + x_diff = float(x) - GLOBAL_MIN_X + y_diff = float(y) - GLOBAL_MIN_Y + distance = float(np.sqrt(x_diff**2 + y_diff**2)) + + # Calculate value-based score + value_score = float(1.0 / (1.0 + abs(value - GLOBAL_MIN_VALUE))) + distance_score = float(1.0 / (1.0 + distance)) + + # Calculate solution quality based on distance to global minimum + if distance < 0.5: # Very close to the correct solution + solution_quality_multiplier = 1.4 # 40% bonus + elif distance < 1.5: # In the right region + solution_quality_multiplier = 1.15 # 15% bonus + elif distance < 3.0: # Getting closer + solution_quality_multiplier = 1.0 # No adjustment + else: # Not finding the right region + solution_quality_multiplier = 0.8 # 20% penalty + + # Calculate combined score for stage 1 + base_score = 0.6 * value_score + 0.4 * distance_score + combined_score = float(base_score * solution_quality_multiplier) + + # Add artifacts for successful stage 1 + stage1_artifacts = { + "stage1_result": f"Found solution at x={x:.4f}, y={y:.4f} with value={value:.4f}", + "distance_to_global": f"{distance:.4f}", + "solution_quality": f"Distance < 0.5: Very close" if distance < 0.5 else f"Distance < 1.5: Good region" if distance < 1.5 else "Could be improved" + } + + return EvaluationResult( + metrics={ + "runs_successfully": 1.0, + "value_score": value_score, + "distance_score": distance_score, + "combined_score": combined_score, + }, + artifacts=stage1_artifacts + ) + except TimeoutError as e: + print(f"Stage 1 evaluation timed out: {e}") + + error_artifacts = { + "error_type": "TimeoutError", + "error_message": "Stage 1: Function execution exceeded 5 second timeout", + "suggestion": "Function is likely stuck in infinite loop or doing too much computation. Try reducing iterations or adding early termination conditions" + } + + return EvaluationResult( + metrics={ + "runs_successfully": 0.0, + "combined_score": 0.0, + "error": "Timeout" + }, + artifacts=error_artifacts + ) + except IndexError as e: + # Specifically handle IndexError which often happens with early termination checks + print(f"Stage 1 evaluation failed with IndexError: {e}") + print("This is likely due to a list index check before the list is fully populated.") + + error_artifacts = { + "error_type": "IndexError", + "error_message": f"Stage 1: {str(e)}", + "suggestion": "List index out of range - likely accessing empty list or wrong index. Check list initialization and bounds" + } + + return EvaluationResult( + metrics={ + "runs_successfully": 0.0, + "combined_score": 0.0, + "error": f"IndexError: {str(e)}" + }, + artifacts=error_artifacts + ) + except Exception as e: + print(f"Stage 1 evaluation failed: {e}") + print(traceback.format_exc()) + + error_artifacts = { + "error_type": type(e).__name__, + "error_message": f"Stage 1: {str(e)}", + "full_traceback": traceback.format_exc(), + "suggestion": "Unexpected error occurred. 
Check the traceback for specific issue" + } + + return EvaluationResult( + metrics={ + "runs_successfully": 0.0, + "combined_score": 0.0, + "error": str(e) + }, + artifacts=error_artifacts + ) + + except Exception as e: + print(f"Stage 1 evaluation failed: {e}") + print(traceback.format_exc()) + + error_artifacts = { + "error_type": type(e).__name__, + "error_message": f"Stage 1 outer exception: {str(e)}", + "full_traceback": traceback.format_exc(), + "suggestion": "Critical error during stage 1 evaluation. Check program syntax and imports" + } + + return EvaluationResult( + metrics={ + "runs_successfully": 0.0, + "combined_score": 0.0, + "error": str(e) + }, + artifacts=error_artifacts + ) + + +def evaluate_stage2(program_path): + """Second stage evaluation with more thorough testing""" + # Full evaluation as in the main evaluate function + return evaluate(program_path) diff --git a/examples/claude_code_quickstart/initial_program.py b/examples/claude_code_quickstart/initial_program.py new file mode 100644 index 0000000000..670c02cc45 --- /dev/null +++ b/examples/claude_code_quickstart/initial_program.py @@ -0,0 +1,51 @@ +# EVOLVE-BLOCK-START +"""Function minimization example for OpenEvolve""" +import numpy as np + + +def search_algorithm(iterations=1000, bounds=(-5, 5)): + """ + A simple random search algorithm that often gets stuck in local minima. + + Args: + iterations: Number of iterations to run + bounds: Bounds for the search space (min, max) + + Returns: + Tuple of (best_x, best_y, best_value) + """ + # Initialize with a random point + best_x = np.random.uniform(bounds[0], bounds[1]) + best_y = np.random.uniform(bounds[0], bounds[1]) + best_value = evaluate_function(best_x, best_y) + + for _ in range(iterations): + # Simple random search + x = np.random.uniform(bounds[0], bounds[1]) + y = np.random.uniform(bounds[0], bounds[1]) + value = evaluate_function(x, y) + + if value < best_value: + best_value = value + best_x, best_y = x, y + + return best_x, best_y, best_value + + +# EVOLVE-BLOCK-END + + +# This part remains fixed (not evolved) +def evaluate_function(x, y): + """The complex function we're trying to minimize""" + return np.sin(x) * np.cos(y) + np.sin(x * y) + (x**2 + y**2) / 20 + + +def run_search(): + x, y, value = search_algorithm() + return x, y, value + + +if __name__ == "__main__": + x, y, value = run_search() + print(f"Found minimum at ({x}, {y}) with value {value}") diff --git a/openevolve/config.py b/openevolve/config.py index bef193da21..fd03ed7a98 100644 --- a/openevolve/config.py +++ b/openevolve/config.py @@ -56,6 +56,9 @@ class LLMModelConfig: api_key: Optional[str] = None name: str = None + # LLM provider: "openai" (default), "claude_code" (Claude Code CLI) + provider: Optional[str] = None + # Custom LLM client init_client: Optional[Callable] = None @@ -79,6 +82,9 @@ class LLMModelConfig: # Reasoning parameters reasoning_effort: Optional[str] = None + # Claude Code CLI budget per call (USD) + max_budget_usd: Optional[float] = None + # Manual mode (human-in-the-loop) manual_mode: Optional[bool] = None _manual_queue_dir: Optional[str] = None diff --git a/openevolve/llm/__init__.py b/openevolve/llm/__init__.py index 26bbef5676..856843d4b0 100644 --- a/openevolve/llm/__init__.py +++ b/openevolve/llm/__init__.py @@ -5,5 +5,12 @@ from openevolve.llm.base import LLMInterface from openevolve.llm.ensemble import LLMEnsemble from openevolve.llm.openai import OpenAILLM +from openevolve.llm.claude_code import ClaudeCodeLLM, init_claude_code_client -__all__ = 
["LLMInterface", "OpenAILLM", "LLMEnsemble"] +__all__ = [ + "LLMInterface", + "OpenAILLM", + "ClaudeCodeLLM", + "init_claude_code_client", + "LLMEnsemble", +] diff --git a/openevolve/llm/claude_code.py b/openevolve/llm/claude_code.py new file mode 100644 index 0000000000..ddbe2e3b7f --- /dev/null +++ b/openevolve/llm/claude_code.py @@ -0,0 +1,138 @@ +""" +Claude Code CLI interface for LLMs. + +Uses the Claude Code CLI (`claude -p`) as a non-interactive LLM backend, +enabling OpenEvolve to run with Anthropic's Claude models without requiring +direct API keys — authentication is handled by the CLI's OAuth session. + +Usage in config.yaml: + llm: + provider: "claude_code" + models: + - name: "sonnet" + weight: 1.0 + max_tokens: 16000 + timeout: 300 + +Or inject programmatically: + from openevolve.llm.claude_code import init_claude_code_client + for model_cfg in config.llm.models: + model_cfg.init_client = init_claude_code_client +""" + +import asyncio +import logging +import subprocess +from typing import Dict, List, Optional + +from openevolve.llm.base import LLMInterface + +logger = logging.getLogger(__name__) + + +class ClaudeCodeLLM(LLMInterface): + """LLM interface that uses the Claude Code CLI for generation. + + Requires `claude` CLI to be installed and authenticated + (run `claude login` first). + """ + + def __init__(self, model_cfg=None): + self.model = getattr(model_cfg, "name", "sonnet") + self.system_message = getattr(model_cfg, "system_message", None) + self.max_tokens = getattr(model_cfg, "max_tokens", 16000) + self.timeout = getattr(model_cfg, "timeout", 300) + self.weight = getattr(model_cfg, "weight", 1.0) + self.retries = getattr(model_cfg, "retries", 3) + self.retry_delay = getattr(model_cfg, "retry_delay", 5) + self.max_budget_usd = getattr(model_cfg, "max_budget_usd", 1.0) + self.cwd = getattr(model_cfg, "cwd", None) + logger.info(f"Initialized ClaudeCodeLLM with model: {self.model}") + + async def generate(self, prompt: str, **kwargs) -> str: + sys_msg = kwargs.pop("system_message", self.system_message) or "" + return await self.generate_with_context( + system_message=sys_msg, + messages=[{"role": "user", "content": prompt}], + **kwargs, + ) + + async def generate_with_context( + self, system_message: str, messages: List[Dict[str, str]], **kwargs + ) -> str: + user_content = "\n\n".join( + m.get("content", "") for m in messages if m.get("role") == "user" + ) + + cmd = [ + "claude", + "-p", + "--model", + self.model, + "--no-session-persistence", + "--output-format", + "text", + ] + if system_message: + cmd.extend(["--system-prompt", system_message]) + + budget = kwargs.get("max_budget_usd", self.max_budget_usd) + cmd.extend(["--max-budget-usd", str(budget)]) + + cmd.append(user_content) + + timeout = kwargs.get("timeout", self.timeout) + retries = kwargs.get("retries", self.retries) + retry_delay = kwargs.get("retry_delay", self.retry_delay) + + loop = asyncio.get_event_loop() + for attempt in range(retries + 1): + try: + result = await asyncio.wait_for( + loop.run_in_executor(None, lambda: self._run_cli(cmd, timeout)), + timeout=timeout + 30, + ) + return result + except asyncio.TimeoutError: + if attempt < retries: + logger.warning( + f"Claude Code CLI timeout on attempt {attempt + 1}/{retries + 1}. Retrying..." + ) + await asyncio.sleep(retry_delay) + else: + logger.error(f"All {retries + 1} attempts failed with timeout") + raise + except Exception as e: + if attempt < retries: + logger.warning( + f"Claude Code CLI error on attempt {attempt + 1}/{retries + 1}: {e}. 
Retrying..." + ) + await asyncio.sleep(retry_delay) + else: + logger.error(f"All {retries + 1} attempts failed with error: {e}") + raise + + def _run_cli(self, cmd: list, timeout: int) -> str: + try: + result = subprocess.run( + cmd, + capture_output=True, + text=True, + timeout=timeout, + cwd=self.cwd, + ) + if result.returncode != 0: + stderr = result.stderr.strip() + if stderr: + logger.warning(f"Claude CLI stderr: {stderr[:500]}") + output = result.stdout.strip() + if not output: + raise RuntimeError(f"Empty response from Claude CLI. stderr: {result.stderr[:500]}") + return output + except subprocess.TimeoutExpired: + raise asyncio.TimeoutError("Claude CLI subprocess timed out") + + +def init_claude_code_client(model_cfg): + """Factory function compatible with OpenEvolve's init_client config hook.""" + return ClaudeCodeLLM(model_cfg) diff --git a/openevolve/llm/ensemble.py b/openevolve/llm/ensemble.py index e3c4716735..b9161382a1 100644 --- a/openevolve/llm/ensemble.py +++ b/openevolve/llm/ensemble.py @@ -13,6 +13,25 @@ logger = logging.getLogger(__name__) +_PROVIDER_REGISTRY = { + "openai": lambda cfg: OpenAILLM(cfg), +} + +try: + from openevolve.llm.claude_code import ClaudeCodeLLM + _PROVIDER_REGISTRY["claude_code"] = lambda cfg: ClaudeCodeLLM(cfg) +except ImportError: + pass + + +def _create_model(model_cfg: LLMModelConfig) -> LLMInterface: + if model_cfg.init_client: + return model_cfg.init_client(model_cfg) + provider = getattr(model_cfg, "provider", None) + if provider and provider in _PROVIDER_REGISTRY: + return _PROVIDER_REGISTRY[provider](model_cfg) + return OpenAILLM(model_cfg) + class LLMEnsemble: """Ensemble of LLMs""" @@ -21,10 +40,7 @@ def __init__(self, models_cfg: List[LLMModelConfig]): self.models_cfg = models_cfg # Initialize models from the configuration - self.models = [ - model_cfg.init_client(model_cfg) if model_cfg.init_client else OpenAILLM(model_cfg) - for model_cfg in models_cfg - ] + self.models = [_create_model(model_cfg) for model_cfg in models_cfg] # Extract and normalize model weights self.weights = [model.weight for model in models_cfg] diff --git a/tests/test_claude_code_llm.py b/tests/test_claude_code_llm.py new file mode 100644 index 0000000000..b33a9d62e8 --- /dev/null +++ b/tests/test_claude_code_llm.py @@ -0,0 +1,160 @@ +"""Tests for the Claude Code CLI LLM backend.""" + +import asyncio +import unittest +from unittest.mock import MagicMock, patch + +from openevolve.llm.claude_code import ClaudeCodeLLM, init_claude_code_client + + +class TestClaudeCodeLLM(unittest.TestCase): + def _make_cfg(self, **overrides): + cfg = MagicMock() + cfg.name = overrides.get("name", "sonnet") + cfg.system_message = overrides.get("system_message", None) + cfg.max_tokens = overrides.get("max_tokens", 16000) + cfg.timeout = overrides.get("timeout", 60) + cfg.weight = overrides.get("weight", 1.0) + cfg.retries = overrides.get("retries", 3) + cfg.retry_delay = overrides.get("retry_delay", 5) + cfg.max_budget_usd = overrides.get("max_budget_usd", 1.0) + cfg.cwd = overrides.get("cwd", None) + return cfg + + def test_init_defaults(self): + llm = ClaudeCodeLLM(self._make_cfg()) + self.assertEqual(llm.model, "sonnet") + self.assertEqual(llm.max_tokens, 16000) + self.assertEqual(llm.timeout, 60) + self.assertEqual(llm.weight, 1.0) + + def test_init_with_custom_model(self): + llm = ClaudeCodeLLM(self._make_cfg(name="opus")) + self.assertEqual(llm.model, "opus") + + def test_factory_function(self): + cfg = self._make_cfg() + llm = init_claude_code_client(cfg) + 
self.assertIsInstance(llm, ClaudeCodeLLM) + self.assertEqual(llm.model, "sonnet") + + @patch("openevolve.llm.claude_code.subprocess.run") + def test_generate_calls_cli(self, mock_run): + mock_run.return_value = MagicMock(returncode=0, stdout="Generated response text", stderr="") + llm = ClaudeCodeLLM(self._make_cfg(timeout=10)) + result = asyncio.run(llm.generate("test prompt")) + self.assertEqual(result, "Generated response text") + mock_run.assert_called_once() + cmd = mock_run.call_args[0][0] + self.assertEqual(cmd[0], "claude") + self.assertIn("-p", cmd) + self.assertIn("--model", cmd) + self.assertIn("sonnet", cmd) + self.assertIn("test prompt", cmd) + + @patch("openevolve.llm.claude_code.subprocess.run") + def test_system_message_passed(self, mock_run): + mock_run.return_value = MagicMock(returncode=0, stdout="response", stderr="") + llm = ClaudeCodeLLM(self._make_cfg(timeout=10)) + asyncio.run(llm.generate("prompt", system_message="You are an expert.")) + cmd = mock_run.call_args[0][0] + idx = cmd.index("--system-prompt") + self.assertEqual(cmd[idx + 1], "You are an expert.") + + @patch("openevolve.llm.claude_code.subprocess.run") + def test_empty_response_raises(self, mock_run): + mock_run.return_value = MagicMock(returncode=1, stdout="", stderr="error msg") + llm = ClaudeCodeLLM(self._make_cfg(timeout=10, retries=0)) + with self.assertRaises(RuntimeError): + asyncio.run(llm.generate("test prompt")) + + @patch("openevolve.llm.claude_code.subprocess.run") + def test_retry_on_failure(self, mock_run): + mock_run.side_effect = [ + MagicMock(returncode=1, stdout="", stderr="transient error"), + MagicMock(returncode=0, stdout="success after retry", stderr=""), + ] + llm = ClaudeCodeLLM(self._make_cfg(timeout=10, retries=1, retry_delay=0)) + result = asyncio.run(llm.generate("test prompt", retry_delay=0)) + self.assertEqual(result, "success after retry") + self.assertEqual(mock_run.call_count, 2) + + @patch("openevolve.llm.claude_code.subprocess.run") + def test_retries_exhausted_raises(self, mock_run): + mock_run.return_value = MagicMock(returncode=1, stdout="", stderr="persistent error") + llm = ClaudeCodeLLM(self._make_cfg(timeout=10, retries=2, retry_delay=0)) + with self.assertRaises(RuntimeError): + asyncio.run(llm.generate("test prompt", retry_delay=0)) + self.assertEqual(mock_run.call_count, 3) + + @patch("openevolve.llm.claude_code.subprocess.run") + def test_generate_with_context(self, mock_run): + mock_run.return_value = MagicMock(returncode=0, stdout="ctx response", stderr="") + llm = ClaudeCodeLLM(self._make_cfg(timeout=10)) + result = asyncio.run( + llm.generate_with_context( + system_message="sys", + messages=[ + {"role": "user", "content": "first"}, + {"role": "assistant", "content": "ignored"}, + {"role": "user", "content": "second"}, + ], + ) + ) + self.assertEqual(result, "ctx response") + cmd = mock_run.call_args[0][0] + self.assertIn("first\n\nsecond", cmd[-1]) + + +class TestMaxBudgetConfig(unittest.TestCase): + def test_max_budget_usd_in_model_config(self): + from openevolve.config import LLMModelConfig + + cfg = LLMModelConfig(max_budget_usd=2.5) + self.assertEqual(cfg.max_budget_usd, 2.5) + + def test_max_budget_usd_default_none(self): + from openevolve.config import LLMModelConfig + + cfg = LLMModelConfig() + self.assertIsNone(cfg.max_budget_usd) + + def test_max_budget_usd_from_dict(self): + from openevolve.config import Config + + config = Config.from_dict( + { + "llm": { + "provider": "claude_code", + "models": [{"name": "sonnet", "max_budget_usd": 3.0, 
"weight": 1.0}], + } + } + ) + self.assertEqual(config.llm.models[0].max_budget_usd, 3.0) + + +class TestProviderRegistry(unittest.TestCase): + def test_claude_code_in_registry(self): + from openevolve.llm.ensemble import _PROVIDER_REGISTRY + + self.assertIn("claude_code", _PROVIDER_REGISTRY) + + def test_ensemble_creates_claude_code(self): + from openevolve.llm.ensemble import _create_model + + cfg = MagicMock() + cfg.init_client = None + cfg.provider = "claude_code" + cfg.name = "sonnet" + cfg.system_message = None + cfg.max_tokens = 4096 + cfg.timeout = 60 + cfg.weight = 1.0 + cfg.max_budget_usd = 1.0 + cfg.cwd = None + model = _create_model(cfg) + self.assertIsInstance(model, ClaudeCodeLLM) + + +if __name__ == "__main__": + unittest.main()