Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
c64325c
add ty to dev dependencies
KRRT7 Jan 19, 2026
2d6854a
first working version
KRRT7 Jan 19, 2026
87c44e9
fix: insert global assignments after their function dependencies
KRRT7 Jan 19, 2026
c5f5710
feat: add create-pr CLI command for creating PRs from augmented optim…
KRRT7 Jan 19, 2026
10389b5
fix: transfer module-level compound statements (for/while/with/try) i…
KRRT7 Jan 19, 2026
b234b83
fix: insert compound statements at end of module to avoid NameError
KRRT7 Jan 19, 2026
206fa1c
fix: transfer new function definitions in add_global_assignments
KRRT7 Jan 19, 2026
6936685
pass args
KRRT7 Jan 19, 2026
62291d7
fix: set default git-remote to origin in create-pr command
KRRT7 Jan 19, 2026
a3f889b
feat: store precomputed test reports for create-pr command
KRRT7 Jan 19, 2026
8c1ee8a
Merge branch 'main' into augmented-optimizations
KRRT7 Jan 21, 2026
fe3bf4e
fix: insert new global statements after the globals they depend on
KRRT7 Jan 21, 2026
9850f9f
fix: handle tuple unpacking and chained assignments in dependency tra…
KRRT7 Jan 21, 2026
9efbabd
fix: handle match statements, comprehension/lambda scoping, and circu…
KRRT7 Jan 21, 2026
a9f34ec
Merge branch 'main' into augmented-optimizations
KRRT7 Jan 24, 2026
ded77ac
Merge branch 'main' into augmented-optimizations
KRRT7 Jan 25, 2026
35264f9
Merge branch 'main' into augmented-optimizations
KRRT7 Jan 25, 2026
332eb55
fix: correct variable name in add_global_assignments
KRRT7 Jan 25, 2026
e60e861
refactor: move code_extractor changes to simplify-code-extractor branch
KRRT7 Jan 25, 2026
48430ac
augmented tests
KRRT7 Jan 26, 2026
ccda501
Merge branch 'main' into augmented-optimizations
KRRT7 Jan 28, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 155 additions & 0 deletions codeflash/api/aiservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,75 @@ def optimize_python_code( # noqa: D417
console.rule()
return []

def augmented_optimize(  # noqa: D417
    self,
    source_code: str,
    system_prompt: str,
    user_prompt: str,
    trace_id: str,
    dependency_code: str | None = None,
    n_candidates: int = 3,
) -> list[OptimizedCandidate]:
    """Request optimization candidates from the /ai/augmented-optimize endpoint using caller-supplied prompts.

    Parameters
    ----------
    - source_code (str): The python code to optimize (markdown format).
    - system_prompt (str): Custom system prompt forwarded to the LLM.
    - user_prompt (str): Custom user prompt forwarded to the LLM.
    - trace_id (str): Trace id of this optimization run.
    - dependency_code (str | None): Optional dependency code context.
    - n_candidates (int): Number of candidates to request (capped at 3).

    Returns
    -------
    - list[OptimizedCandidate]: Validated candidates, or an empty list on any error.

    """
    logger.info("Generating augmented optimization candidates…")
    console.rule()
    started = time.perf_counter()
    repo_owner, repo_name = safe_get_repo_owner_and_name()

    request_body = {
        "source_code": source_code,
        "dependency_code": dependency_code,
        "trace_id": trace_id,
        "system_prompt": system_prompt,
        "user_prompt": user_prompt,
        "n_candidates": min(n_candidates, 3),  # server accepts at most 3
        "python_version": platform.python_version(),
        "codeflash_version": codeflash_version,
        "current_username": get_last_commit_author_if_pr_exists(None),
        "repo_owner": repo_owner,
        "repo_name": repo_name,
    }
    logger.debug(f"Sending augmented optimize request: trace_id={trace_id}, n_candidates={request_body['n_candidates']}")

    try:
        response = self.make_ai_service_request("/augmented-optimize", payload=request_body, timeout=self.timeout)
    except requests.exceptions.RequestException as e:
        logger.exception(f"Error generating augmented optimization candidates: {e}")
        ph("cli-augmented-optimize-error-caught", {"error": str(e)})
        console.rule()
        return []

    if response.status_code != 200:
        # Prefer the server-provided JSON error; fall back to the raw body.
        try:
            error = response.json()["error"]
        except Exception:
            error = response.text
        logger.error(f"Error generating augmented optimization candidates: {response.status_code} - {error}")
        ph("cli-augmented-optimize-error-response", {"response_status_code": response.status_code, "error": error})
        console.rule()
        return []

    optimizations_json = response.json()["optimizations"]
    elapsed = time.perf_counter() - started
    logger.debug(f"!lsp|Generating augmented optimizations took {elapsed:.2f} seconds.")
    logger.info(f"!lsp|Received {len(optimizations_json)} augmented optimization candidates.")
    console.rule()
    return self._get_valid_candidates(optimizations_json, OptimizedCandidateSource.AUGMENTED)

def get_jit_rewritten_code( # noqa: D417
self, source_code: str, trace_id: str
) -> list[OptimizedCandidate]:
Expand Down Expand Up @@ -635,6 +704,92 @@ def log_results( # noqa: D417
except requests.exceptions.RequestException as e:
logger.exception(f"Error logging features: {e}")

def augmented_generate_tests(  # noqa: D417
    self,
    source_code_being_tested: str,
    system_prompt: str,
    user_prompt: str,
    function_to_optimize: FunctionToOptimize,
    helper_function_names: list[str],
    module_path: Path,
    test_module_path: Path,
    test_framework: str,
    test_timeout: int,
    trace_id: str,
    test_index: int,
    is_numerical_code: bool | None = None,  # noqa: FBT001
) -> tuple[str, str, str] | None:
    """Generate tests with custom prompts via /ai/augmented-testgen.

    Parameters
    ----------
    - source_code_being_tested (str): The source code of the function being tested.
    - system_prompt (str): Custom system prompt for test generation.
    - user_prompt (str): Custom user prompt for test generation.
    - function_to_optimize (FunctionToOptimize): The function to optimize.
    - helper_function_names (list[str]): List of helper function names.
    - module_path (Path): The module path where the function is located.
    - test_module_path (Path): The module path for the test code.
    - test_framework (str): The test framework to use, e.g., "pytest".
    - test_timeout (int): The timeout for each test in seconds.
    - trace_id (str): Trace id of optimization run.
    - test_index (int): The index from 0-(n-1) if n tests are generated for a single trace_id.
    - is_numerical_code (bool | None): Whether the code is numerical.

    Returns
    -------
    - tuple[str, str, str] | None: The generated regression tests and instrumented tests, or None if an error occurred.

    """
    assert test_framework in ["pytest", "unittest"], (
        f"Invalid test framework, got {test_framework} but expected 'pytest' or 'unittest'"
    )
    payload = {
        "source_code_being_tested": source_code_being_tested,
        "system_prompt": system_prompt,
        "user_prompt": user_prompt,
        "function_to_optimize": function_to_optimize,
        "helper_function_names": helper_function_names,
        "module_path": module_path,
        "test_module_path": test_module_path,
        "test_framework": test_framework,
        "test_timeout": test_timeout,
        "trace_id": trace_id,
        "test_index": test_index,
        "python_version": platform.python_version(),
        "codeflash_version": codeflash_version,
        "is_async": function_to_optimize.is_async,
        "call_sequence": self.get_next_sequence(),
        "is_numerical_code": is_numerical_code,
    }
    try:
        response = self.make_ai_service_request("/augmented-testgen", payload=payload, timeout=self.timeout)
    except requests.exceptions.RequestException as e:
        logger.exception(f"Error generating augmented tests: {e}")
        ph("cli-augmented-testgen-error-caught", {"error": str(e)})
        return None

    if response.status_code == 200:
        response_json = response.json()
        logger.debug(f"Generated augmented tests for function {function_to_optimize.function_name}")
        return (
            response_json["generated_tests"],
            response_json["instrumented_behavior_tests"],
            response_json["instrumented_perf_tests"],
        )

    # Error branch: guard ONLY the JSON parsing with try/except, then log and
    # report once outside it — consistent with augmented_optimize. Previously
    # logger.error/ph were inside the try, so a failure in either would be
    # swallowed and the error re-reported from response.text (and logger.error
    # could fire twice for one response).
    try:
        error = response.json()["error"]
    except Exception:
        error = response.text
    logger.error(f"Error generating augmented tests: {response.status_code} - {error}")
    ph("cli-augmented-testgen-error-response", {"response_status_code": response.status_code, "error": error})
    return None

def generate_regression_tests( # noqa: D417
self,
source_code_being_tested: str,
Expand Down
42 changes: 42 additions & 0 deletions codeflash/cli_cmds/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from codeflash.cli_cmds import logging_config
from codeflash.cli_cmds.cli_common import apologize_and_exit
from codeflash.cli_cmds.cmd_create_pr import create_pr as cmd_create_pr
from codeflash.cli_cmds.cmd_init import init_codeflash, install_github_actions
from codeflash.cli_cmds.console import logger
from codeflash.cli_cmds.extension import install_vscode_extension
Expand Down Expand Up @@ -57,6 +58,22 @@ def parse_args() -> Namespace:
help="The path to the pyproject.toml file which stores the Codeflash config. This is auto-discovered by default.",
)

# create-pr subcommand for creating PRs from augmented optimization results
create_pr_parser = subparsers.add_parser("create-pr", help="Create a PR from previously applied optimizations")
create_pr_parser.set_defaults(func=cmd_create_pr)
create_pr_parser.add_argument(
"--results-file",
type=str,
default="codeflash_phase1_results.json",
help="Path to augmented output JSON file (default: codeflash_phase1_results.json)",
)
create_pr_parser.add_argument(
"--function", type=str, help="Function name (required if multiple functions in results)"
)
create_pr_parser.add_argument(
"--git-remote", type=str, default="origin", help="Git remote to use for PR creation (default: origin)"
)

parser.add_argument("--file", help="Try to optimize only this file")
parser.add_argument("--function", help="Try to optimize only this function within the given file path")
parser.add_argument(
Expand Down Expand Up @@ -120,6 +137,22 @@ def parse_args() -> Namespace:
parser.add_argument(
"--effort", type=str, help="Effort level for optimization", choices=["low", "medium", "high"], default="medium"
)
parser.add_argument(
"--augmented",
action="store_true",
help="Enable augmented optimization mode for two-phase optimization workflow",
)
parser.add_argument(
"--augmented-prompt-file",
type=str,
help="Path to YAML file with custom system_prompt and user_prompt for Phase 2",
)
parser.add_argument(
"--augmented-output",
type=str,
default="codeflash_phase1_results.json",
help="Path to write Phase 1 results JSON (default: codeflash_phase1_results.json)",
)

args, unknown_args = parser.parse_known_args()
sys.argv[:] = [sys.argv[0], *unknown_args]
Expand Down Expand Up @@ -178,6 +211,15 @@ def process_and_validate_cmd_args(args: Namespace) -> Namespace:
"Async function optimization is now enabled by default."
)

if args.augmented_prompt_file and not args.augmented:
exit_with_message("--augmented-prompt-file requires --augmented flag", error_on_exit=True)

if args.augmented_prompt_file:
prompt_file = Path(args.augmented_prompt_file)
if not prompt_file.exists():
exit_with_message(f"Augmented prompt file {args.augmented_prompt_file} does not exist", error_on_exit=True)
args.augmented_prompt_file = prompt_file.resolve()

return args


Expand Down
147 changes: 147 additions & 0 deletions codeflash/cli_cmds/cmd_create_pr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
from __future__ import annotations

import json
import re
from pathlib import Path
from typing import TYPE_CHECKING

from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_utils import exit_with_message
from codeflash.code_utils.git_utils import git_root_dir
from codeflash.models.models import Phase1Output, TestResults
from codeflash.result.create_pr import check_create_pr
from codeflash.result.explanation import Explanation

if TYPE_CHECKING:
from argparse import Namespace

# Pattern to extract the source-file path from a markdown code fence of the
# form ```python:path/to/file.py — captures everything after the colon up to
# the end of the fence's first line.
MARKDOWN_FILE_PATH_PATTERN = re.compile(r"```python:([^\n]+)")


def extract_file_path_from_markdown(markdown_code: str) -> str | None:
    """Return the file path embedded in a markdown code fence, if any.

    Looks for the ```python:path/to/file.py fence form and returns the
    whitespace-stripped path, or None when no such fence is present.
    """
    found = re.search(r"```python:([^\n]+)", markdown_code)
    return found.group(1).strip() if found else None


def extract_code_from_markdown(markdown_code: str) -> str:
    r"""Return the bare code content of a markdown ```python code block.

    Strips the opening fence (```python, with an optional :path suffix) and
    the closing ``` fence; text without fences passes through unchanged.
    """
    # Drop the leading fence line, whether or not it carries a :path suffix.
    without_opening_fence = re.sub(r"^```python(?::[^\n]*)?\n", "", markdown_code)
    # Drop the trailing fence.
    without_closing_fence = re.sub(r"\n```$", "", without_opening_fence)
    return without_closing_fence


def create_pr(args: Namespace) -> None:
    """Create a PR from previously applied optimizations.

    Loads the Phase 1 results JSON named by ``args.results_file``, selects the
    target function result (``args.function`` disambiguates when the file
    holds more than one), rebuilds the original-vs-current content pair for
    the optimized source file, and delegates PR creation to
    ``check_create_pr``. On success the results file is deleted.

    Expected ``args`` attributes: ``results_file`` (str path), ``function``
    (str or None), and optionally ``git_remote`` (str or None).

    NOTE(review): every failure path goes through
    ``exit_with_message(..., error_on_exit=True)``, which is assumed to
    terminate the process and never return — the ``assert``s below exist only
    to convey that to the type checker.
    """
    results_file = Path(args.results_file)

    if not results_file.exists():
        exit_with_message(f"Results file not found: {results_file}", error_on_exit=True)

    # Load and parse results
    with results_file.open(encoding="utf-8") as f:
        data = json.load(f)

    try:
        output = Phase1Output.model_validate(data)
    except Exception as e:
        exit_with_message(f"Failed to parse results file: {e}", error_on_exit=True)

    # Find the function result
    if len(output.functions) == 0:
        exit_with_message("No functions in results file", error_on_exit=True)

    if len(output.functions) > 1 and not args.function:
        func_names = [f.function_name for f in output.functions]
        exit_with_message(
            f"Multiple functions in results. Specify one with --function: {func_names}", error_on_exit=True
        )

    # Single-function default; --function overrides (and must match a name).
    func_result = output.functions[0]
    if args.function:
        func_result = next((f for f in output.functions if f.function_name == args.function), None)
        if not func_result:
            exit_with_message(f"Function {args.function} not found in results", error_on_exit=True)
    assert func_result is not None  # for type checker - exit_with_message doesn't return

    if not func_result.best_candidate_id:
        exit_with_message("No successful optimization found in results", error_on_exit=True)

    # Get file path - prefer explicit field, fall back to extracting from markdown
    # (older results files carry the path only inside the code-fence header).
    file_path_str = func_result.file_path
    if not file_path_str:
        file_path_str = extract_file_path_from_markdown(func_result.original_source_code)

    if not file_path_str:
        exit_with_message(
            "Could not determine file path from results. Results file may be from an older version of codeflash.",
            error_on_exit=True,
        )
    assert file_path_str is not None  # for type checker - exit_with_message doesn't return

    file_path = Path(file_path_str)
    if not file_path.exists():
        exit_with_message(f"Source file not found: {file_path}", error_on_exit=True)

    # Read current (optimized) file content — the optimization is assumed to
    # already be applied on disk; "new code" for the PR is the file as-is.
    current_content = file_path.read_text(encoding="utf-8")

    # Extract original code (strip markdown)
    original_code = extract_code_from_markdown(func_result.original_source_code)

    # Get the best candidate's explanation
    best_explanation = func_result.best_candidate_explanation
    if not best_explanation:
        # Fall back to the candidate's explanation if the final explanation wasn't captured
        best_candidate = next(
            (c for c in func_result.candidates if c.optimization_id == func_result.best_candidate_id), None
        )
        best_explanation = best_candidate.explanation if best_candidate else "Optimization applied"

    # Build Explanation object for PR creation. Test-result fields are empty
    # because the precomputed report/loop count from Phase 1 are passed to
    # check_create_pr instead.
    explanation = Explanation(
        raw_explanation_message=best_explanation,
        winning_behavior_test_results=TestResults(),
        winning_benchmarking_test_results=TestResults(),
        original_runtime_ns=func_result.original_runtime_ns or 0,
        best_runtime_ns=func_result.best_runtime_ns or func_result.original_runtime_ns or 0,
        function_name=func_result.function_name,
        file_path=file_path,
    )

    logger.info(f"Creating PR for optimized function: {func_result.function_name}")
    logger.info(f"File: {file_path}")
    if func_result.best_speedup_ratio:
        # assumes best_speedup_ratio is a fraction (e.g. 0.25 -> "25.0%") — TODO confirm
        logger.info(f"Speedup: {func_result.best_speedup_ratio * 100:.1f}%")

    # Call existing PR creation
    check_create_pr(
        original_code={file_path: original_code},
        new_code={file_path: current_content},
        explanation=explanation,
        existing_tests_source=func_result.existing_tests_source or "",
        generated_original_test_source="",
        function_trace_id=func_result.trace_id,
        coverage_message="",
        replay_tests=func_result.replay_tests_source or "",
        concolic_tests=func_result.concolic_tests_source or "",
        optimization_review="",
        root_dir=git_root_dir(),
        git_remote=getattr(args, "git_remote", None),
        precomputed_test_report=func_result.test_report,
        precomputed_loop_count=func_result.loop_count,
    )

    # Cleanup results file after successful PR creation
    results_file.unlink()
    logger.info(f"Cleaned up results file: {results_file}")
Loading
Loading