diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py
index 96e467072..5b7d49668 100644
--- a/codeflash/api/aiservice.py
+++ b/codeflash/api/aiservice.py
@@ -193,6 +193,75 @@ def optimize_python_code(  # noqa: D417
         console.rule()
         return []
 
+    def augmented_optimize(  # noqa: D417
+        self,
+        source_code: str,
+        system_prompt: str,
+        user_prompt: str,
+        trace_id: str,
+        dependency_code: str | None = None,
+        n_candidates: int = 3,
+    ) -> list[OptimizedCandidate]:
+        """Optimize code with custom prompts via /ai/augmented-optimize endpoint.
+
+        Parameters
+        ----------
+        - source_code (str): The python code to optimize (markdown format).
+        - system_prompt (str): Custom system prompt for the LLM.
+        - user_prompt (str): Custom user prompt for the LLM.
+        - trace_id (str): Trace id of optimization run.
+        - dependency_code (str | None): Optional dependency code context.
+        - n_candidates (int): Number of candidates to generate (max 3).
+
+        Returns
+        -------
+        - list[OptimizedCandidate]: A list of Optimization Candidates.
+
+        """
+        logger.info("Generating augmented optimization candidates…")
+        console.rule()
+        start_time = time.perf_counter()
+        git_repo_owner, git_repo_name = safe_get_repo_owner_and_name()
+
+        payload = {
+            "source_code": source_code,
+            "dependency_code": dependency_code,
+            "trace_id": trace_id,
+            "system_prompt": system_prompt,
+            "user_prompt": user_prompt,
+            "n_candidates": min(n_candidates, 3),
+            "python_version": platform.python_version(),
+            "codeflash_version": codeflash_version,
+            "current_username": get_last_commit_author_if_pr_exists(None),
+            "repo_owner": git_repo_owner,
+            "repo_name": git_repo_name,
+        }
+        logger.debug(f"Sending augmented optimize request: trace_id={trace_id}, n_candidates={payload['n_candidates']}")
+
+        try:
+            response = self.make_ai_service_request("/augmented-optimize", payload=payload, timeout=self.timeout)
+        except requests.exceptions.RequestException as e:
+            logger.exception(f"Error generating augmented optimization candidates: {e}")
+            ph("cli-augmented-optimize-error-caught", {"error": str(e)})
+            console.rule()
+            return []
+
+        if response.status_code == 200:
+            optimizations_json = response.json()["optimizations"]
+            end_time = time.perf_counter()
+            logger.debug(f"!lsp|Generating augmented optimizations took {end_time - start_time:.2f} seconds.")
+            logger.info(f"!lsp|Received {len(optimizations_json)} augmented optimization candidates.")
+            console.rule()
+            return self._get_valid_candidates(optimizations_json, OptimizedCandidateSource.AUGMENTED)
+        try:
+            error = response.json()["error"]
+        except Exception:
+            error = response.text
+        logger.error(f"Error generating augmented optimization candidates: {response.status_code} - {error}")
+        ph("cli-augmented-optimize-error-response", {"response_status_code": response.status_code, "error": error})
+        console.rule()
+        return []
+
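+    # A 200 response is parsed above as {"optimizations": [...]}; an
+    # illustrative (assumed) entry shape, before _get_valid_candidates
+    # filtering:
+    #   {"optimization_id": "...", "source_code": "...", "explanation": "..."}
+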
"""Generate tests with custom prompts via /ai/augmented-testgen. + + Parameters + ---------- + - source_code_being_tested (str): The source code of the function being tested. + - system_prompt (str): Custom system prompt for test generation. + - user_prompt (str): Custom user prompt for test generation. + - function_to_optimize (FunctionToOptimize): The function to optimize. + - helper_function_names (list[str]): List of helper function names. + - module_path (Path): The module path where the function is located. + - test_module_path (Path): The module path for the test code. + - test_framework (str): The test framework to use, e.g., "pytest". + - test_timeout (int): The timeout for each test in seconds. + - trace_id (str): Trace id of optimization run. + - test_index (int): The index from 0-(n-1) if n tests are generated for a single trace_id. + - is_numerical_code (bool | None): Whether the code is numerical. + + Returns + ------- + - tuple[str, str, str] | None: The generated regression tests and instrumented tests, or None if an error occurred. + + """ + assert test_framework in ["pytest", "unittest"], ( + f"Invalid test framework, got {test_framework} but expected 'pytest' or 'unittest'" + ) + payload = { + "source_code_being_tested": source_code_being_tested, + "system_prompt": system_prompt, + "user_prompt": user_prompt, + "function_to_optimize": function_to_optimize, + "helper_function_names": helper_function_names, + "module_path": module_path, + "test_module_path": test_module_path, + "test_framework": test_framework, + "test_timeout": test_timeout, + "trace_id": trace_id, + "test_index": test_index, + "python_version": platform.python_version(), + "codeflash_version": codeflash_version, + "is_async": function_to_optimize.is_async, + "call_sequence": self.get_next_sequence(), + "is_numerical_code": is_numerical_code, + } + try: + response = self.make_ai_service_request("/augmented-testgen", payload=payload, timeout=self.timeout) + except requests.exceptions.RequestException as e: + logger.exception(f"Error generating augmented tests: {e}") + ph("cli-augmented-testgen-error-caught", {"error": str(e)}) + return None + + if response.status_code == 200: + response_json = response.json() + logger.debug(f"Generated augmented tests for function {function_to_optimize.function_name}") + return ( + response_json["generated_tests"], + response_json["instrumented_behavior_tests"], + response_json["instrumented_perf_tests"], + ) + try: + error = response.json()["error"] + logger.error(f"Error generating augmented tests: {response.status_code} - {error}") + ph("cli-augmented-testgen-error-response", {"response_status_code": response.status_code, "error": error}) + return None # noqa: TRY300 + except Exception: + logger.error(f"Error generating augmented tests: {response.status_code} - {response.text}") + ph( + "cli-augmented-testgen-error-response", + {"response_status_code": response.status_code, "error": response.text}, + ) + return None + def generate_regression_tests( # noqa: D417 self, source_code_being_tested: str, diff --git a/codeflash/cli_cmds/cli.py b/codeflash/cli_cmds/cli.py index e135cd022..65702a7f1 100644 --- a/codeflash/cli_cmds/cli.py +++ b/codeflash/cli_cmds/cli.py @@ -6,6 +6,7 @@ from codeflash.cli_cmds import logging_config from codeflash.cli_cmds.cli_common import apologize_and_exit +from codeflash.cli_cmds.cmd_create_pr import create_pr as cmd_create_pr from codeflash.cli_cmds.cmd_init import init_codeflash, install_github_actions from codeflash.cli_cmds.console import 
     def generate_regression_tests(  # noqa: D417
         self,
         source_code_being_tested: str,
diff --git a/codeflash/cli_cmds/cli.py b/codeflash/cli_cmds/cli.py
index e135cd022..65702a7f1 100644
--- a/codeflash/cli_cmds/cli.py
+++ b/codeflash/cli_cmds/cli.py
@@ -6,6 +6,7 @@
 
 from codeflash.cli_cmds import logging_config
 from codeflash.cli_cmds.cli_common import apologize_and_exit
+from codeflash.cli_cmds.cmd_create_pr import create_pr as cmd_create_pr
 from codeflash.cli_cmds.cmd_init import init_codeflash, install_github_actions
 from codeflash.cli_cmds.console import logger
 from codeflash.cli_cmds.extension import install_vscode_extension
@@ -57,6 +58,22 @@ def parse_args() -> Namespace:
         help="The path to the pyproject.toml file which stores the Codeflash config. This is auto-discovered by default.",
     )
 
+    # create-pr subcommand for creating PRs from augmented optimization results
+    create_pr_parser = subparsers.add_parser("create-pr", help="Create a PR from previously applied optimizations")
+    create_pr_parser.set_defaults(func=cmd_create_pr)
+    create_pr_parser.add_argument(
+        "--results-file",
+        type=str,
+        default="codeflash_phase1_results.json",
+        help="Path to augmented output JSON file (default: codeflash_phase1_results.json)",
+    )
+    create_pr_parser.add_argument(
+        "--function", type=str, help="Function name (required if multiple functions in results)"
+    )
+    create_pr_parser.add_argument(
+        "--git-remote", type=str, default="origin", help="Git remote to use for PR creation (default: origin)"
+    )
+
     parser.add_argument("--file", help="Try to optimize only this file")
     parser.add_argument("--function", help="Try to optimize only this function within the given file path")
     parser.add_argument(
@@ -120,6 +137,22 @@ def parse_args() -> Namespace:
     parser.add_argument(
         "--effort", type=str, help="Effort level for optimization", choices=["low", "medium", "high"], default="medium"
     )
+    parser.add_argument(
+        "--augmented",
+        action="store_true",
+        help="Enable augmented optimization mode for two-phase optimization workflow",
+    )
+    parser.add_argument(
+        "--augmented-prompt-file",
+        type=str,
+        help="Path to YAML file with custom system_prompt and user_prompt for Phase 2",
+    )
+    parser.add_argument(
+        "--augmented-output",
+        type=str,
+        default="codeflash_phase1_results.json",
+        help="Path to write Phase 1 results JSON (default: codeflash_phase1_results.json)",
+    )
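+    # Illustrative two-phase flow using the flags defined above (paths and
+    # names are placeholders):
+    #   codeflash --file module.py --function slow_fn --augmented \
+    #       --augmented-prompt-file prompts.yaml --augmented-output results.json
+    #   codeflash create-pr --results-file results.json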
 
     args, unknown_args = parser.parse_known_args()
     sys.argv[:] = [sys.argv[0], *unknown_args]
@@ -178,6 +211,15 @@ def process_and_validate_cmd_args(args: Namespace) -> Namespace:
             "Async function optimization is now enabled by default."
         )
 
+    if args.augmented_prompt_file and not args.augmented:
+        exit_with_message("--augmented-prompt-file requires --augmented flag", error_on_exit=True)
+
+    if args.augmented_prompt_file:
+        prompt_file = Path(args.augmented_prompt_file)
+        if not prompt_file.exists():
+            exit_with_message(f"Augmented prompt file {args.augmented_prompt_file} does not exist", error_on_exit=True)
+        args.augmented_prompt_file = prompt_file.resolve()
+
     return args
 
 
diff --git a/codeflash/cli_cmds/cmd_create_pr.py b/codeflash/cli_cmds/cmd_create_pr.py
new file mode 100644
index 000000000..902c6ff8e
--- /dev/null
+++ b/codeflash/cli_cmds/cmd_create_pr.py
@@ -0,0 +1,147 @@
+from __future__ import annotations
+
+import json
+import re
+from pathlib import Path
+from typing import TYPE_CHECKING
+
+from codeflash.cli_cmds.console import logger
+from codeflash.code_utils.code_utils import exit_with_message
+from codeflash.code_utils.git_utils import git_root_dir
+from codeflash.models.models import Phase1Output, TestResults
+from codeflash.result.create_pr import check_create_pr
+from codeflash.result.explanation import Explanation
+
+if TYPE_CHECKING:
+    from argparse import Namespace
+
+# Pattern to extract file path from markdown code block: ```python:path/to/file.py
+MARKDOWN_FILE_PATH_PATTERN = re.compile(r"```python:([^\n]+)")
+
+
+def extract_file_path_from_markdown(markdown_code: str) -> str | None:
+    """Extract the file path from markdown code block format.
+
+    Format: ```python:path/to/file.py
+    """
+    match = MARKDOWN_FILE_PATH_PATTERN.search(markdown_code)
+    if match:
+        return match.group(1).strip()
+    return None
+
+
+def extract_code_from_markdown(markdown_code: str) -> str:
+    r"""Extract the code content from markdown code block.
+
+    Removes the ```python:path\n ... ``` wrapper.
+    """
+    # Remove opening markdown fence with optional path
+    code = re.sub(r"^```python(?::[^\n]*)?\n", "", markdown_code)
+    # Remove closing fence
+    return re.sub(r"\n```$", "", code)
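+
+# Illustrative round trip through the two helpers above:
+#   extract_file_path_from_markdown("```python:src/app.py\nx = 1\n```") -> "src/app.py"
+#   extract_code_from_markdown("```python:src/app.py\nx = 1\n```") -> "x = 1"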
+
+
+def create_pr(args: Namespace) -> None:
+    """Create a PR from previously applied optimizations."""
+    results_file = Path(args.results_file)
+
+    if not results_file.exists():
+        exit_with_message(f"Results file not found: {results_file}", error_on_exit=True)
+
+    # Load and parse results
+    with results_file.open(encoding="utf-8") as f:
+        data = json.load(f)
+
+    try:
+        output = Phase1Output.model_validate(data)
+    except Exception as e:
+        exit_with_message(f"Failed to parse results file: {e}", error_on_exit=True)
+
+    # Find the function result
+    if len(output.functions) == 0:
+        exit_with_message("No functions in results file", error_on_exit=True)
+
+    if len(output.functions) > 1 and not args.function:
+        func_names = [f.function_name for f in output.functions]
+        exit_with_message(
+            f"Multiple functions in results. Specify one with --function: {func_names}", error_on_exit=True
+        )
+
+    func_result = output.functions[0]
+    if args.function:
+        func_result = next((f for f in output.functions if f.function_name == args.function), None)
+        if not func_result:
+            exit_with_message(f"Function {args.function} not found in results", error_on_exit=True)
+    assert func_result is not None  # for type checker - exit_with_message doesn't return
+
+    if not func_result.best_candidate_id:
+        exit_with_message("No successful optimization found in results", error_on_exit=True)
+
+    # Get file path - prefer explicit field, fall back to extracting from markdown
+    file_path_str = func_result.file_path
+    if not file_path_str:
+        file_path_str = extract_file_path_from_markdown(func_result.original_source_code)
+
+    if not file_path_str:
+        exit_with_message(
+            "Could not determine file path from results. Results file may be from an older version of codeflash.",
+            error_on_exit=True,
+        )
+    assert file_path_str is not None  # for type checker - exit_with_message doesn't return
+
+    file_path = Path(file_path_str)
+    if not file_path.exists():
+        exit_with_message(f"Source file not found: {file_path}", error_on_exit=True)
+
+    # Read current (optimized) file content
+    current_content = file_path.read_text(encoding="utf-8")
+
+    # Extract original code (strip markdown)
+    original_code = extract_code_from_markdown(func_result.original_source_code)
+
+    # Get the best candidate's explanation
+    best_explanation = func_result.best_candidate_explanation
+    if not best_explanation:
+        # Fall back to the candidate's explanation if the final explanation wasn't captured
+        best_candidate = next(
+            (c for c in func_result.candidates if c.optimization_id == func_result.best_candidate_id), None
+        )
+        best_explanation = best_candidate.explanation if best_candidate else "Optimization applied"
+
+    # Build Explanation object for PR creation
+    explanation = Explanation(
+        raw_explanation_message=best_explanation,
+        winning_behavior_test_results=TestResults(),
+        winning_benchmarking_test_results=TestResults(),
+        original_runtime_ns=func_result.original_runtime_ns or 0,
+        best_runtime_ns=func_result.best_runtime_ns or func_result.original_runtime_ns or 0,
+        function_name=func_result.function_name,
+        file_path=file_path,
+    )
+
+    logger.info(f"Creating PR for optimized function: {func_result.function_name}")
+    logger.info(f"File: {file_path}")
+    if func_result.best_speedup_ratio:
+        logger.info(f"Speedup: {func_result.best_speedup_ratio * 100:.1f}%")
+
+    # Call existing PR creation
+    check_create_pr(
+        original_code={file_path: original_code},
+        new_code={file_path: current_content},
+        explanation=explanation,
+        existing_tests_source=func_result.existing_tests_source or "",
+        generated_original_test_source="",
+        function_trace_id=func_result.trace_id,
+        coverage_message="",
+        replay_tests=func_result.replay_tests_source or "",
+        concolic_tests=func_result.concolic_tests_source or "",
+        optimization_review="",
+        root_dir=git_root_dir(),
+        git_remote=getattr(args, "git_remote", None),
+        precomputed_test_report=func_result.test_report,
+        precomputed_loop_count=func_result.loop_count,
+    )
+
+    # Cleanup results file after successful PR creation
+    results_file.unlink()
+    logger.info(f"Cleaned up results file: {results_file}")
diff --git a/codeflash/github/PrComment.py b/codeflash/github/PrComment.py
index fe0ff095e..3a1021d54 100644
--- a/codeflash/github/PrComment.py
+++ b/codeflash/github/PrComment.py
@@ -23,13 +23,25 @@ class PrComment:
     benchmark_details: Optional[list[BenchmarkDetail]] = None
     original_async_throughput: Optional[int] = None
     best_async_throughput: Optional[int] = None
+    # Optional pre-computed values (used by create-pr CLI command)
+    precomputed_test_report: Optional[dict[str, dict[str, int]]] = None
+    precomputed_loop_count: Optional[int] = None
 
     def to_json(self) -> dict[str, Union[str, int, dict[str, dict[str, int]], list[BenchmarkDetail], None]]:
-        report_table = {
-            test_type.to_name(): result
-            for test_type, result in self.winning_behavior_test_results.get_test_pass_fail_report_by_type().items()
-            if test_type.to_name()
-        }
+        # Use precomputed values if available, otherwise compute from TestResults
+        if self.precomputed_test_report is not None:
+            report_table = self.precomputed_test_report
+        else:
+            report_table = {
+                test_type.to_name(): result
+                for test_type, result in self.winning_behavior_test_results.get_test_pass_fail_report_by_type().items()
+                if test_type.to_name()
+            }
+        loop_count = (
+            self.precomputed_loop_count
+            if self.precomputed_loop_count is not None
+            else self.winning_benchmarking_test_results.number_of_loops()
+        )
 
         result: dict[str, Union[str, int, dict[str, dict[str, int]], list[BenchmarkDetail], None]] = {
             "optimization_explanation": self.optimization_explanation,
@@ -39,7 +51,7 @@ class PrComment:
             "file_path": self.relative_file_path,
             "speedup_x": self.speedup_x,
             "speedup_pct": self.speedup_pct,
-            "loop_count": self.winning_benchmarking_test_results.number_of_loops(),
+            "loop_count": loop_count,
             "report_table": report_table,
             "benchmark_details": self.benchmark_details if self.benchmark_details else None,
         }
diff --git a/codeflash/main.py b/codeflash/main.py
index 31afd0305..67880a7fc 100644
--- a/codeflash/main.py
+++ b/codeflash/main.py
@@ -32,7 +32,7 @@ def main() -> None:
         disable_telemetry = pyproject_config.get("disable_telemetry", False)
         init_sentry(not disable_telemetry, exclude_errors=True)
         posthog_cf.initialize_posthog(not disable_telemetry)
-        args.func()
+        args.func(args)
     elif args.verify_setup:
         args = process_pyproject_config(args)
         init_sentry(not args.disable_telemetry, exclude_errors=True)
diff --git a/codeflash/models/models.py b/codeflash/models/models.py
index dc5b82923..2a4d87297 100644
--- a/codeflash/models/models.py
+++ b/codeflash/models/models.py
@@ -529,6 +529,7 @@ class OptimizedCandidateSource(str, Enum):
     REPAIR = "REPAIR"
     ADAPTIVE = "ADAPTIVE"
    JIT_REWRITE = "JIT_REWRITE"
+    AUGMENTED = "AUGMENTED"
 
 
 @dataclass(frozen=True)
@@ -955,3 +956,53 @@ def __eq__(self, other: object) -> bool:
                 return False
     sys.setrecursionlimit(original_recursion_limit)
     return True
+
+
+class Phase1CandidateResult(BaseModel):
+    optimization_id: str
+    source_code: str
+    explanation: str
+    speedup_ratio: Optional[float] = None
+    runtime_ns: Optional[int] = None
+    is_correct: bool
+    line_profiler_results: Optional[str] = None
+    test_failures: Optional[list[str]] = None
+    test_diffs: Optional[list[dict]] = None
+
+
+class Phase1FunctionResult(BaseModel):
+    function_name: str
+    trace_id: str
+    original_source_code: str
+    dependency_code: Optional[str] = None
+    original_runtime_ns: Optional[int] = None
+    original_line_profiler_results: Optional[str] = None
+    candidates: list[Phase1CandidateResult]
+    best_candidate_id: Optional[str] = None
+    best_speedup_ratio: Optional[float] = None
+    # PR creation data - captured after best candidate is selected
+    file_path: Optional[str] = None
+    existing_tests_source: Optional[str] = None
+    replay_tests_source: Optional[str] = None
+    concolic_tests_source: Optional[str] = None
+    best_candidate_explanation: Optional[str] = None
+    best_runtime_ns: Optional[int] = None
+    # Test results summary for PR creation
+    test_report: Optional[dict[str, dict[str, int]]] = None  # test_type_name -> {passed: int, failed: int}
+    loop_count: Optional[int] = None
+
+
+class Phase1Output(BaseModel):
+    codeflash_version: str
+    timestamp: str
+    python_version: str
+    functions: list[Phase1FunctionResult]
+    total_functions: int
+    successful_optimizations: int
+
+
+class AugmentedPrompts(BaseModel):
+    system_prompt: str
+    user_prompt: str
+    testgen_system_prompt: Optional[str] = None
+    testgen_user_prompt: Optional[str] = None
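+
+
+# An illustrative Phase1Output payload as serialized by model_dump_json
+# (values are placeholders; field names match the models above):
+#   {"codeflash_version": "...", "timestamp": "2025-01-01T00:00:00+00:00",
+#    "python_version": "3.12.0", "total_functions": 1,
+#    "successful_optimizations": 1,
+#    "functions": [{"function_name": "slow_fn", "trace_id": "...",
+#                   "original_source_code": "...", "candidates": [...],
+#                   "best_candidate_id": "..."}]}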
diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py
index 256b65b7a..be78daf34 100644
--- a/codeflash/optimization/function_optimizer.py
+++ b/codeflash/optimization/function_optimizer.py
@@ -4,16 +4,19 @@
 import concurrent.futures
 import logging
 import os
+import platform
 import queue
 import random
 import subprocess
 import uuid
 from collections import defaultdict
+from datetime import datetime, timezone
 from pathlib import Path
 from typing import TYPE_CHECKING, Callable
 
 import libcst as cst
 import sentry_sdk
+import yaml
 from rich.console import Group
 from rich.panel import Panel
 from rich.syntax import Syntax
@@ -81,6 +84,7 @@
     AdaptiveOptimizedCandidate,
     AIServiceAdaptiveOptimizeRequest,
     AIServiceCodeRepairRequest,
+    AugmentedPrompts,
     BestOptimization,
     CandidateEvaluationContext,
     CodeOptimizationContext,
@@ -92,6 +96,9 @@
     OptimizedCandidateResult,
     OptimizedCandidateSource,
     OriginalCodeBaseline,
+    Phase1CandidateResult,
+    Phase1FunctionResult,
+    Phase1Output,
     TestFile,
     TestFiles,
     TestingMode,
@@ -121,7 +128,7 @@
 )
 from codeflash.verification.test_runner import run_behavioral_tests, run_benchmarking_tests, run_line_profile_tests
 from codeflash.verification.verification_utils import get_test_file_path
-from codeflash.verification.verifier import generate_tests
+from codeflash.verification.verifier import generate_augmented_tests, generate_tests
 
 if TYPE_CHECKING:
     from argparse import Namespace
@@ -475,6 +482,122 @@ def __init__(
         self.adaptive_optimization_counter = 0  # track how many adaptive optimizations we did for each function
         self.is_numerical_code: bool | None = None
+        self.augmented_mode = getattr(args, "augmented", False) if args else False
+        self.augmented_prompt_file = getattr(args, "augmented_prompt_file", None) if args else None
+        self.augmented_output = getattr(args, "augmented_output", "codeflash_phase1_results.json") if args else None
+        self.augmented_prompts: AugmentedPrompts | None = None
+        self.phase1_candidate_results: list[Phase1CandidateResult] = []
+        # PR data captured during process_review for Phase1 output
+        self.phase1_existing_tests: str | None = None
+        self.phase1_replay_tests: str | None = None
+        self.phase1_concolic_tests: str | None = None
+        self.phase1_best_explanation: str | None = None
+        self.phase1_best_runtime: int | None = None
+        self.phase1_test_report: dict[str, dict[str, int]] | None = None
+        self.phase1_loop_count: int | None = None
+
+    def load_augmented_prompts(self) -> AugmentedPrompts | None:
+        if not self.augmented_prompt_file:
+            return None
+        prompt_path = Path(self.augmented_prompt_file)
+        if not prompt_path.exists():
+            logger.error(f"Augmented prompt file not found: {prompt_path}")
+            return None
+        with prompt_path.open(encoding="utf-8") as f:
+            data = yaml.safe_load(f)
+        if not data or "system_prompt" not in data or "user_prompt" not in data:
+            logger.error("Augmented prompt file must contain 'system_prompt' and 'user_prompt' keys")
+            return None
+        return AugmentedPrompts(
+            system_prompt=data["system_prompt"],
+            user_prompt=data["user_prompt"],
+            testgen_system_prompt=data.get("testgen_system_prompt"),
+            testgen_user_prompt=data.get("testgen_user_prompt"),
+        )
+
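+    # A minimal prompt file accepted by load_augmented_prompts (testgen_* keys
+    # are optional; contents are illustrative):
+    #
+    #   system_prompt: |
+    #     You are an expert Python performance engineer...
+    #   user_prompt: |
+    #     Optimize the function below while preserving behavior...
+    #   testgen_system_prompt: |
+    #     You write rigorous regression tests...
+    #   testgen_user_prompt: |
+    #     Generate tests for the function below...
+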
+    def collect_phase1_candidate_result(
+        self,
+        candidate: OptimizedCandidate,
+        eval_ctx: CandidateEvaluationContext,
+        test_failures: list[str] | None = None,
+        test_diffs: list[dict] | None = None,
+    ) -> Phase1CandidateResult:
+        speedup = eval_ctx.get_speedup_ratio(candidate.optimization_id)
+        runtime = eval_ctx.get_optimized_runtime(candidate.optimization_id)
+        is_correct = eval_ctx.is_correct.get(candidate.optimization_id, False)
+        line_profiler = eval_ctx.optimized_line_profiler_results.get(candidate.optimization_id)
+        return Phase1CandidateResult(
+            optimization_id=candidate.optimization_id,
+            source_code=candidate.source_code.markdown,
+            explanation=candidate.explanation,
+            speedup_ratio=speedup,
+            runtime_ns=int(runtime) if runtime else None,
+            is_correct=is_correct,
+            line_profiler_results=line_profiler,
+            test_failures=test_failures,
+            test_diffs=test_diffs,
+        )
+
+    def get_phase1_function_result(
+        self,
+        code_context: CodeOptimizationContext,
+        original_runtime_ns: int | None,
+        original_line_profiler_results: str | None,
+        best_candidate_id: str | None,
+        best_speedup_ratio: float | None,
+    ) -> Phase1FunctionResult:
+        return Phase1FunctionResult(
+            function_name=self.function_to_optimize.function_name,
+            trace_id=self.function_trace_id,
+            original_source_code=code_context.read_writable_code.markdown,
+            dependency_code=code_context.read_only_context_code if code_context.read_only_context_code else None,
+            original_runtime_ns=original_runtime_ns,
+            original_line_profiler_results=original_line_profiler_results,
+            candidates=self.phase1_candidate_results,
+            best_candidate_id=best_candidate_id,
+            best_speedup_ratio=best_speedup_ratio,
+            # PR creation data
+            file_path=self.function_to_optimize.file_path.as_posix(),
+            existing_tests_source=self.phase1_existing_tests,
+            replay_tests_source=self.phase1_replay_tests,
+            concolic_tests_source=self.phase1_concolic_tests,
+            best_candidate_explanation=self.phase1_best_explanation,
+            best_runtime_ns=self.phase1_best_runtime,
+            test_report=self.phase1_test_report,
+            loop_count=self.phase1_loop_count,
+        )
+
+    def write_phase1_output(self, function_result: Phase1FunctionResult) -> None:
+        from codeflash.version import __version__ as codeflash_version
+
+        output_path = Path(self.augmented_output) if self.augmented_output else Path("codeflash_phase1_results.json")
+        output = Phase1Output(
+            codeflash_version=codeflash_version,
+            timestamp=datetime.now(tz=timezone.utc).isoformat(),
+            python_version=platform.python_version(),
+            functions=[function_result],
+            total_functions=1,
+            successful_optimizations=1 if function_result.best_candidate_id else 0,
+        )
+        with output_path.open("w", encoding="utf-8") as f:
+            f.write(output.model_dump_json(indent=2))
+        logger.info(f"Phase 1 results written to {output_path}")
+
+    def get_augmented_candidates(self, code_context: CodeOptimizationContext) -> list[OptimizedCandidate]:
+        if not self.augmented_prompts:
+            logger.error("Augmented prompts not loaded")
+            return []
+        candidates = self.aiservice_client.augmented_optimize(
+            source_code=code_context.read_writable_code.markdown,
+            system_prompt=self.augmented_prompts.system_prompt,
+            user_prompt=self.augmented_prompts.user_prompt,
+            trace_id=self.function_trace_id,
+            dependency_code=code_context.read_only_context_code if code_context.read_only_context_code else None,
+            n_candidates=3,
+        )
+        logger.info(f"Received {len(candidates)} augmented optimization candidates")
+        return candidates
+
     def can_be_optimized(self) -> Result[tuple[bool, CodeOptimizationContext, dict[Path, str]], str]:
         should_run_experiment = self.experiment_id is not None
         logger.info(f"!lsp|Function Trace ID: {self.function_trace_id}")
@@ -642,6 +765,7 @@ def optimize_function(self) -> Result[BestOptimization, str]:
             read_only_context_code=code_context.read_only_context_code,
             run_experiment=should_run_experiment,
             is_numerical_code=self.is_numerical_code and not self.args.no_jit_opts,
+            code_context=code_context,
         )
 
         concurrent.futures.wait([future_tests, future_optimizations])
@@ -709,6 +833,23 @@ def optimize_function(self) -> Result[BestOptimization, str]:
 
         if self.args.override_fixtures:
             restore_conftest(original_conftest_content)
+
+        if self.augmented_mode:
+            function_result = self.get_phase1_function_result(
+                code_context=code_context,
+                original_runtime_ns=original_code_baseline.runtime if original_code_baseline else None,
+                original_line_profiler_results=original_code_baseline.line_profile_results.get("str_out")
+                if original_code_baseline and original_code_baseline.line_profile_results
+                else None,
+                best_candidate_id=best_optimization.candidate.optimization_id if best_optimization else None,
+                best_speedup_ratio=performance_gain(
+                    original_runtime_ns=original_code_baseline.runtime, optimized_runtime_ns=best_optimization.runtime
+                )
+                if best_optimization
+                else None,
+            )
+            self.write_phase1_output(function_result)
+
         if not best_optimization:
             return Failure(f"No best optimizations found for function {self.function_to_optimize.qualified_name}")
         return Success(best_optimization)
@@ -1223,6 +1364,11 @@ def determine_best_candidate(
                 exp_type=exp_type,
             )
 
+        if self.augmented_mode:
+            for candidate in candidates:
+                phase1_result = self.collect_phase1_candidate_result(candidate=candidate, eval_ctx=eval_ctx)
+                self.phase1_candidate_results.append(phase1_result)
+
         return best_optimization
 
     def call_adaptive_optimize(
@@ -1629,8 +1775,31 @@ def generate_optimizations(
         read_only_context_code: str,
         run_experiment: bool = False,  # noqa: FBT001, FBT002
         is_numerical_code: bool | None = None,  # noqa: FBT001
+        code_context: CodeOptimizationContext | None = None,
     ) -> Result[tuple[OptimizationSet, str], str]:
         """Generate optimization candidates for the function. Backend handles multi-model diversity."""
+        if self.augmented_mode and self.augmented_prompt_file:
+            self.augmented_prompts = self.load_augmented_prompts()
+            if not self.augmented_prompts:
+                return Failure("Failed to load augmented prompts from file")
+            if code_context is None:
+                return Failure("Code context required for augmented mode")
+            augmented_candidates = self.get_augmented_candidates(code_context)
+            if not augmented_candidates:
+                return Failure(
+                    f"/!\\ NO AUGMENTED OPTIMIZATIONS GENERATED for {self.function_to_optimize.function_name}"
+                )
+            future_references = self.executor.submit(
+                get_opt_review_metrics,
+                self.function_to_optimize_source_code,
+                self.function_to_optimize.file_path,
+                self.function_to_optimize.qualified_name,
+                self.project_root,
+                self.test_cfg.tests_root,
+            )
+            function_references = future_references.result()
+            return Success((OptimizationSet(control=augmented_candidates, experiment=None), function_references))
+
         n_candidates = get_effort_value(EffortKeys.N_OPTIMIZER_CANDIDATES, self.effort)
         future_optimization_candidates = self.executor.submit(
             self.aiservice_client.optimize_python_code,
@@ -1989,6 +2158,21 @@ def process_review(
 
         best_optimization.explanation_v2 = new_explanation.explanation_message()
 
+        # Capture PR data for Phase1 output in augmented mode
+        if self.augmented_mode:
+            self.phase1_existing_tests = existing_tests
+            self.phase1_replay_tests = replay_tests
+            self.phase1_concolic_tests = concolic_tests
+            self.phase1_best_explanation = new_explanation.explanation_message()
+            self.phase1_best_runtime = best_optimization.runtime
+            # Capture test results for PR creation
+            self.phase1_test_report = {
+                tt.to_name(): counts
+                for tt, counts in new_explanation.winning_behavior_test_results.get_test_pass_fail_report_by_type().items()
+                if tt.to_name()  # Skip empty names
+            }
+            self.phase1_loop_count = new_explanation.winning_benchmarking_test_results.number_of_loops()
+
         data = {
             "original_code": original_code_combined,
             "new_code": new_code_combined,
@@ -2556,26 +2740,53 @@ def submit_test_generation_tasks(
         generated_test_paths: list[Path],
         generated_perf_test_paths: list[Path],
     ) -> list[concurrent.futures.Future]:
-        return [
-            executor.submit(
-                generate_tests,
-                self.aiservice_client,
-                source_code_being_tested,
-                self.function_to_optimize,
-                helper_function_names,
-                Path(self.original_module_path),
-                self.test_cfg,
-                INDIVIDUAL_TESTCASE_TIMEOUT,
-                self.function_trace_id,
-                test_index,
-                test_path,
-                test_perf_path,
-                self.is_numerical_code,
-            )
-            for test_index, (test_path, test_perf_path) in enumerate(
-                zip(generated_test_paths, generated_perf_test_paths)
-            )
-        ]
+        # Check if augmented testgen prompts are available (Round 2+ with custom test prompts)
+        use_augmented_testgen = (
+            self.augmented_prompts is not None
+            and self.augmented_prompts.testgen_system_prompt is not None
+            and self.augmented_prompts.testgen_user_prompt is not None
+        )
+
+        futures = []
+        for test_index, (test_path, test_perf_path) in enumerate(zip(generated_test_paths, generated_perf_test_paths)):
+            if use_augmented_testgen:
+                # Use augmented test generation with custom prompts from cc-plugin
+                future = executor.submit(
+                    generate_augmented_tests,
+                    self.aiservice_client,
+                    source_code_being_tested,
+                    self.augmented_prompts.testgen_system_prompt,
+                    self.augmented_prompts.testgen_user_prompt,
+                    self.function_to_optimize,
+                    helper_function_names,
+                    Path(self.original_module_path),
+                    self.test_cfg,
+                    INDIVIDUAL_TESTCASE_TIMEOUT,
+                    self.function_trace_id,
+                    test_index,
+                    test_path,
+                    test_perf_path,
+                    self.is_numerical_code,
+                )
+            else:
+                # Use normal test generation (Round 1 or no custom prompts)
+                future = executor.submit(
+                    generate_tests,
+                    self.aiservice_client,
+                    source_code_being_tested,
+                    self.function_to_optimize,
+                    helper_function_names,
+                    Path(self.original_module_path),
+                    self.test_cfg,
+                    INDIVIDUAL_TESTCASE_TIMEOUT,
+                    self.function_trace_id,
+                    test_index,
+                    test_path,
+                    test_perf_path,
+                    self.is_numerical_code,
+                )
+            futures.append(future)
+        return futures
 
     def cleanup_generated_files(self) -> None:
         paths_to_cleanup = []
diff --git a/codeflash/result/create_pr.py b/codeflash/result/create_pr.py
index f888f710a..da3ca1a72 100644
--- a/codeflash/result/create_pr.py
+++ b/codeflash/result/create_pr.py
@@ -186,6 +186,8 @@ def check_create_pr(
     root_dir: Path,
     git_remote: Optional[str] = None,
     optimization_review: str = "",
+    precomputed_test_report: Optional[dict[str, dict[str, int]]] = None,
+    precomputed_loop_count: Optional[int] = None,
 ) -> None:
     pr_number: Optional[int] = env_utils.get_pr_number()
     git_repo = git.Repo(search_parent_directories=True)
@@ -222,6 +224,8 @@ def check_create_pr(
                 benchmark_details=explanation.benchmark_details,
                 original_async_throughput=explanation.original_async_throughput,
                 best_async_throughput=explanation.best_async_throughput,
+                precomputed_test_report=precomputed_test_report,
+                precomputed_loop_count=precomputed_loop_count,
             ),
             existing_tests=existing_tests_source,
             generated_tests=generated_original_test_source,
@@ -274,6 +278,8 @@ def check_create_pr(
                 benchmark_details=explanation.benchmark_details,
                 original_async_throughput=explanation.original_async_throughput,
                 best_async_throughput=explanation.best_async_throughput,
+                precomputed_test_report=precomputed_test_report,
+                precomputed_loop_count=precomputed_loop_count,
             ),
             existing_tests=existing_tests_source,
             generated_tests=generated_original_test_source,
diff --git a/codeflash/verification/verifier.py b/codeflash/verification/verifier.py
index f60718020..e44aa8a56 100644
--- a/codeflash/verification/verifier.py
+++ b/codeflash/verification/verifier.py
@@ -69,6 +69,66 @@ def generate_tests(
     )
 
 
+def generate_augmented_tests(
+    aiservice_client: AiServiceClient,
+    source_code_being_tested: str,
+    system_prompt: str,
+    user_prompt: str,
+    function_to_optimize: FunctionToOptimize,
+    helper_function_names: list[str],
+    module_path: Path,
+    test_cfg: TestConfig,
+    test_timeout: int,
+    function_trace_id: str,
+    test_index: int,
+    test_path: Path,
+    test_perf_path: Path,
+    is_numerical_code: bool | None = None,  # noqa: FBT001
+) -> tuple[str, str, str, Path, Path] | None:
+    """Generate tests with custom prompts for augmented mode.
+
+    This mirrors generate_tests but uses custom system/user prompts from the cc-plugin.
+    """
+    start_time = time.perf_counter()
+    test_module_path = Path(module_name_from_file_path(test_path, test_cfg.tests_project_rootdir))
+    response = aiservice_client.augmented_generate_tests(
+        source_code_being_tested=source_code_being_tested,
+        system_prompt=system_prompt,
+        user_prompt=user_prompt,
+        function_to_optimize=function_to_optimize,
+        helper_function_names=helper_function_names,
+        module_path=module_path,
+        test_module_path=test_module_path,
+        test_framework=test_cfg.test_framework,
+        test_timeout=test_timeout,
+        trace_id=function_trace_id,
+        test_index=test_index,
+        is_numerical_code=is_numerical_code,
+    )
+    if response and isinstance(response, tuple) and len(response) == 3:
+        generated_test_source, instrumented_behavior_test_source, instrumented_perf_test_source = response
+        temp_run_dir = get_run_tmp_file(Path()).as_posix()
+
+        instrumented_behavior_test_source = instrumented_behavior_test_source.replace(
+            "{codeflash_run_tmp_dir_client_side}", temp_run_dir
+        )
+        instrumented_perf_test_source = instrumented_perf_test_source.replace(
+            "{codeflash_run_tmp_dir_client_side}", temp_run_dir
+        )
+    else:
+        logger.warning(f"Failed to generate augmented tests for {function_to_optimize.function_name}")
+        return None
+    end_time = time.perf_counter()
+    logger.debug(f"Generated augmented tests in {end_time - start_time:.2f} seconds")
+    return (
+        generated_test_source,
+        instrumented_behavior_test_source,
+        instrumented_perf_test_source,
+        test_path,
+        test_perf_path,
+    )
+
+
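+# Callers unpack the returned five-tuple (illustrative sketch):
+#   result = generate_augmented_tests(...)
+#   if result:
+#       test_src, behavior_src, perf_src, test_path, perf_path = result
+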
"python_full_version >= '3.10'" }, { name = "pytest-timeout" }, + { name = "pyyaml" }, { name = "rich" }, { name = "sentry-sdk" }, { name = "tomlkit" }, @@ -451,6 +452,7 @@ dev = [ { name = "pre-commit", version = "4.3.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" }, { name = "pre-commit", version = "4.5.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, { name = "ruff" }, + { name = "ty" }, { name = "types-cffi" }, { name = "types-colorama" }, { name = "types-decorator" }, @@ -517,6 +519,7 @@ requires-dist = [ { name = "pytest", specifier = ">=7.0.0" }, { name = "pytest-asyncio", specifier = ">=1.2.0" }, { name = "pytest-timeout", specifier = ">=2.1.0" }, + { name = "pyyaml", specifier = ">=6.0.3" }, { name = "rich", specifier = ">=13.8.1" }, { name = "sentry-sdk", specifier = ">=1.40.6,<3.0.0" }, { name = "tomlkit", specifier = ">=0.11.7" }, @@ -531,6 +534,7 @@ dev = [ { name = "pandas-stubs", specifier = ">=2.2.2.240807,<2.2.3.241009" }, { name = "pre-commit", specifier = ">=4.2.0,<5" }, { name = "ruff", specifier = ">=0.7.0" }, + { name = "ty", specifier = ">=0.0.12" }, { name = "types-cffi", specifier = ">=1.16.0.20240331" }, { name = "types-colorama", specifier = ">=0.4.15.20240311" }, { name = "types-decorator", specifier = ">=5.1.8.20240310" }, @@ -5154,6 +5158,31 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/16/b5/b0d3d8b901b6a04ca38df5e24c27e53afb15b93624d7fd7d658c7cd9352a/triton-3.5.1-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bac7f7d959ad0f48c0e97d6643a1cc0fd5786fe61cb1f83b537c6b2d54776478", size = 170582192, upload-time = "2025-11-11T17:41:23.963Z" }, ] +[[package]] +name = "ty" +version = "0.0.12" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b5/78/ba1a4ad403c748fbba8be63b7e774a90e80b67192f6443d624c64fe4aaab/ty-0.0.12.tar.gz", hash = "sha256:cd01810e106c3b652a01b8f784dd21741de9fdc47bd595d02c122a7d5cefeee7", size = 4981303, upload-time = "2026-01-14T22:30:48.537Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7d/8f/c21314d074dda5fb13d3300fa6733fd0d8ff23ea83a721818740665b6314/ty-0.0.12-py3-none-linux_armv6l.whl", hash = "sha256:eb9da1e2c68bd754e090eab39ed65edf95168d36cbeb43ff2bd9f86b4edd56d1", size = 9614164, upload-time = "2026-01-14T22:30:44.016Z" }, + { url = "https://files.pythonhosted.org/packages/09/28/f8a4d944d13519d70c486e8f96d6fa95647ac2aa94432e97d5cfec1f42f6/ty-0.0.12-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:c181f42aa19b0ed7f1b0c2d559980b1f1d77cc09419f51c8321c7ddf67758853", size = 9542337, upload-time = "2026-01-14T22:30:05.687Z" }, + { url = "https://files.pythonhosted.org/packages/e1/9c/f576e360441de7a8201daa6dc4ebc362853bc5305e059cceeb02ebdd9a48/ty-0.0.12-py3-none-macosx_11_0_arm64.whl", hash = "sha256:1f829e1eecd39c3e1b032149db7ae6a3284f72fc36b42436e65243a9ed1173db", size = 8909582, upload-time = "2026-01-14T22:30:46.089Z" }, + { url = "https://files.pythonhosted.org/packages/d6/13/0898e494032a5d8af3060733d12929e3e7716db6c75eac63fa125730a3e7/ty-0.0.12-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f45162e7826e1789cf3374627883cdeb0d56b82473a0771923e4572928e90be3", size = 9384932, upload-time = "2026-01-14T22:30:13.769Z" }, + { url = 
"https://files.pythonhosted.org/packages/e4/1a/b35b6c697008a11d4cedfd34d9672db2f0a0621ec80ece109e13fca4dfef/ty-0.0.12-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d11fec40b269bec01e751b2337d1c7ffa959a2c2090a950d7e21c2792442cccd", size = 9453140, upload-time = "2026-01-14T22:30:11.131Z" }, + { url = "https://files.pythonhosted.org/packages/dd/1e/71c9edbc79a3c88a0711324458f29c7dbf6c23452c6e760dc25725483064/ty-0.0.12-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:09d99e37e761a4d2651ad9d5a610d11235fbcbf35dc6d4bc04abf54e7cf894f1", size = 9960680, upload-time = "2026-01-14T22:30:33.621Z" }, + { url = "https://files.pythonhosted.org/packages/0e/75/39375129f62dd22f6ad5a99cd2a42fd27d8b91b235ce2db86875cdad397d/ty-0.0.12-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:d9ca0cdb17bd37397da7b16a7cd23423fc65c3f9691e453ad46c723d121225a1", size = 10904518, upload-time = "2026-01-14T22:30:08.464Z" }, + { url = "https://files.pythonhosted.org/packages/32/5e/26c6d88fafa11a9d31ca9f4d12989f57782ec61e7291d4802d685b5be118/ty-0.0.12-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fcf2757b905e7eddb7e456140066335b18eb68b634a9f72d6f54a427ab042c64", size = 10525001, upload-time = "2026-01-14T22:30:16.454Z" }, + { url = "https://files.pythonhosted.org/packages/c2/a5/2f0b91894af13187110f9ad7ee926d86e4e6efa755c9c88a820ed7f84c85/ty-0.0.12-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:00cf34c1ebe1147efeda3021a1064baa222c18cdac114b7b050bbe42deb4ca80", size = 10307103, upload-time = "2026-01-14T22:30:41.221Z" }, + { url = "https://files.pythonhosted.org/packages/4b/77/13d0410827e4bc713ebb7fdaf6b3590b37dcb1b82e0a81717b65548f2442/ty-0.0.12-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2bb3a655bd869352e9a22938d707631ac9fbca1016242b1f6d132d78f347c851", size = 10072737, upload-time = "2026-01-14T22:30:51.783Z" }, + { url = "https://files.pythonhosted.org/packages/e1/dd/fc36d8bac806c74cf04b4ca735bca14d19967ca84d88f31e121767880df1/ty-0.0.12-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:4658e282c7cb82be304052f8f64f9925f23c3c4f90eeeb32663c74c4b095d7ba", size = 9368726, upload-time = "2026-01-14T22:30:18.683Z" }, + { url = "https://files.pythonhosted.org/packages/54/70/9e8e461647550f83e2fe54bc632ccbdc17a4909644783cdbdd17f7296059/ty-0.0.12-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:c167d838eaaa06e03bb66a517f75296b643d950fbd93c1d1686a187e5a8dbd1f", size = 9454704, upload-time = "2026-01-14T22:30:22.759Z" }, + { url = "https://files.pythonhosted.org/packages/04/9b/6292cf7c14a0efeca0539cf7d78f453beff0475cb039fbea0eb5d07d343d/ty-0.0.12-py3-none-musllinux_1_2_i686.whl", hash = "sha256:2956e0c9ab7023533b461d8a0e6b2ea7b78e01a8dde0688e8234d0fce10c4c1c", size = 9649829, upload-time = "2026-01-14T22:30:31.234Z" }, + { url = "https://files.pythonhosted.org/packages/49/bd/472a5d2013371e4870886cff791c94abdf0b92d43d305dd0f8e06b6ff719/ty-0.0.12-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:5c6a3fd7479580009f21002f3828320621d8a82d53b7ba36993234e3ccad58c8", size = 10162814, upload-time = "2026-01-14T22:30:36.174Z" }, + { url = "https://files.pythonhosted.org/packages/31/e9/2ecbe56826759845a7c21d80aa28187865ea62bc9757b056f6cbc06f78ed/ty-0.0.12-py3-none-win32.whl", hash = "sha256:a91c24fd75c0f1796d8ede9083e2c0ec96f106dbda73a09fe3135e075d31f742", size = 9140115, upload-time = "2026-01-14T22:30:38.903Z" }, + { url = 
"https://files.pythonhosted.org/packages/5d/6d/d9531eff35a5c0ec9dbc10231fac21f9dd6504814048e81d6ce1c84dc566/ty-0.0.12-py3-none-win_amd64.whl", hash = "sha256:df151894be55c22d47068b0f3b484aff9e638761e2267e115d515fcc9c5b4a4b", size = 9884532, upload-time = "2026-01-14T22:30:25.112Z" }, + { url = "https://files.pythonhosted.org/packages/e9/f3/20b49e75967023b123a221134548ad7000f9429f13fdcdda115b4c26305f/ty-0.0.12-py3-none-win_arm64.whl", hash = "sha256:cea99d334b05629de937ce52f43278acf155d3a316ad6a35356635f886be20ea", size = 9313974, upload-time = "2026-01-14T22:30:27.44Z" }, +] + [[package]] name = "types-cffi" version = "1.17.0.20250915"