Skip to content
Merged
5 changes: 3 additions & 2 deletions codeflash/cli_cmds/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ def _handle_show_config() -> None:
detected = detect_project(project_root)

# Check if config exists or is auto-detected
config_exists = has_existing_config(project_root)
config_exists, _ = has_existing_config(project_root)
status = "Saved config" if config_exists else "Auto-detected (not saved)"

console.print()
Expand Down Expand Up @@ -400,7 +400,8 @@ def _handle_reset_config(confirm: bool = True) -> None:

project_root = Path.cwd()

if not has_existing_config(project_root):
config_exists, _ = has_existing_config(project_root)
if not config_exists:
console.print("[yellow]No Codeflash configuration found to remove.[/yellow]")
return

Expand Down
17 changes: 8 additions & 9 deletions codeflash/cli_cmds/init_javascript.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from git import InvalidGitRepositoryError, Repo
from rich.console import Group
from rich.panel import Panel
from rich.prompt import Confirm
from rich.table import Table
from rich.text import Text

Expand All @@ -26,7 +27,6 @@
from codeflash.code_utils.git_utils import get_git_remotes
from codeflash.code_utils.shell_utils import get_shell_rc_path, is_powershell
from codeflash.telemetry.posthog_cf import ph
from rich.prompt import Confirm


class ProjectLanguage(Enum):
Expand Down Expand Up @@ -165,22 +165,21 @@ def get_package_install_command(project_root: Path, package: str, dev: bool = Tr
if dev:
cmd.append("--save-dev")
return cmd
elif pkg_manager == JsPackageManager.YARN:
if pkg_manager == JsPackageManager.YARN:
cmd = ["yarn", "add", package]
if dev:
cmd.append("--dev")
return cmd
elif pkg_manager == JsPackageManager.BUN:
if pkg_manager == JsPackageManager.BUN:
cmd = ["bun", "add", package]
if dev:
cmd.append("--dev")
return cmd
else:
# Default to npm
cmd = ["npm", "install", package]
if dev:
cmd.append("--save-dev")
return cmd
# Default to npm
cmd = ["npm", "install", package]
if dev:
cmd.append("--save-dev")
return cmd


def init_js_project(language: ProjectLanguage) -> None:
Expand Down
36 changes: 18 additions & 18 deletions codeflash/code_utils/code_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1577,9 +1577,11 @@ def get_opt_review_metrics(

Returns:
Markdown-formatted string with code blocks showing calling functions.

"""
from codeflash.languages.base import FunctionInfo, ParentInfo, ReferenceInfo
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.registry import get_language_support
from codeflash.models.models import FunctionParent

start_time = time.perf_counter()

Expand All @@ -1596,19 +1598,19 @@ def get_opt_review_metrics(
else:
function_name, class_name = qualified_name_split[1], qualified_name_split[0]

# Create a FunctionInfo for the function
# Create a FunctionToOptimize for the function
# We don't have full line info here, so we'll use defaults
parents = ()
parents: list[FunctionParent] = []
if class_name:
parents = (ParentInfo(name=class_name, type="ClassDef"),)
parents = [FunctionParent(name=class_name, type="ClassDef")]

func_info = FunctionInfo(
name=function_name,
func_info = FunctionToOptimize(
function_name=function_name,
file_path=file_path,
start_line=1,
end_line=1,
parents=parents,
language=language,
starting_line=1,
ending_line=1,
language=str(language),
)

# Find references using language support
Expand All @@ -1618,9 +1620,7 @@ def get_opt_review_metrics(
return ""

# Format references as markdown code blocks
calling_fns_details = _format_references_as_markdown(
references, file_path, project_root, language
)
calling_fns_details = _format_references_as_markdown(references, file_path, project_root, language)

except Exception as e:
logger.debug(f"Error getting function references: {e}")
Expand All @@ -1631,9 +1631,7 @@ def get_opt_review_metrics(
return calling_fns_details


def _format_references_as_markdown(
references: list, file_path: Path, project_root: Path, language: Language
) -> str:
def _format_references_as_markdown(references: list, file_path: Path, project_root: Path, language: Language) -> str:
"""Format references as markdown code blocks with calling function code.

Args:
Expand All @@ -1644,6 +1642,7 @@ def _format_references_as_markdown(

Returns:
Markdown-formatted string.

"""
# Group references by file
refs_by_file: dict[Path, list] = {}
Expand Down Expand Up @@ -1710,7 +1709,7 @@ def _format_references_as_markdown(
context_len += len(context_code)

if caller_contexts:
fn_call_context += f"```{lang_hint}:{path_relative}\n"
fn_call_context += f"```{lang_hint}:{path_relative.as_posix()}\n"
fn_call_context += "\n".join(caller_contexts)
fn_call_context += "\n```\n"

Expand All @@ -1728,11 +1727,11 @@ def _extract_calling_function(source_code: str, function_name: str, ref_line: in

Returns:
Source code of the function, or None if not found.

"""
if language == Language.PYTHON:
return _extract_calling_function_python(source_code, function_name, ref_line)
else:
return _extract_calling_function_js(source_code, function_name, ref_line)
return _extract_calling_function_js(source_code, function_name, ref_line)


def _extract_calling_function_python(source_code: str, function_name: str, ref_line: int) -> str | None:
Expand Down Expand Up @@ -1766,6 +1765,7 @@ def _extract_calling_function_js(source_code: str, function_name: str, ref_line:

Returns:
Source code of the function, or None if not found.

"""
try:
from codeflash.languages.treesitter_utils import TreeSitterAnalyzer, TreeSitterLanguage
Expand Down
30 changes: 11 additions & 19 deletions codeflash/code_utils/code_replacer.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,7 @@ def replace_function_definitions_for_language(

"""
from codeflash.languages import get_language_support
from codeflash.languages.base import FunctionInfo, Language, ParentInfo
from codeflash.languages.base import Language

original_source_code: str = module_abspath.read_text(encoding="utf8")
code_to_apply = get_optimized_code_for_module(module_abspath.relative_to(project_root_path), optimized_code)
Expand All @@ -523,25 +523,15 @@ def replace_function_definitions_for_language(
and function_to_optimize.ending_line
and function_to_optimize.file_path == module_abspath
):
parents = tuple(ParentInfo(name=p.name, type=p.type) for p in function_to_optimize.parents)
func_info = FunctionInfo(
name=function_to_optimize.function_name,
file_path=module_abspath,
start_line=function_to_optimize.starting_line,
end_line=function_to_optimize.ending_line,
parents=parents,
is_async=function_to_optimize.is_async,
language=language,
)
# Extract just the target function from the optimized code
optimized_func = _extract_function_from_code(
lang_support, code_to_apply, function_to_optimize.function_name, module_abspath
)
if optimized_func:
new_code = lang_support.replace_function(original_source_code, func_info, optimized_func)
new_code = lang_support.replace_function(original_source_code, function_to_optimize, optimized_func)
else:
# Fallback: use the entire optimized code (for simple single-function files)
new_code = lang_support.replace_function(original_source_code, func_info, code_to_apply)
new_code = lang_support.replace_function(original_source_code, function_to_optimize, code_to_apply)
else:
# For helper files or when we don't have precise line info:
# Find each function by name in both original and optimized code
Expand All @@ -559,15 +549,17 @@ def replace_function_definitions_for_language(
# Find the function in current code
func = None
for f in current_functions:
if func_name in (f.qualified_name, f.name):
if func_name in (f.qualified_name, f.function_name):
func = f
break

if func is None:
continue

# Extract just this function from the optimized code
optimized_func = _extract_function_from_code(lang_support, code_to_apply, func.name, module_abspath)
optimized_func = _extract_function_from_code(
lang_support, code_to_apply, func.function_name, module_abspath
)
if optimized_func:
new_code = lang_support.replace_function(new_code, func, optimized_func)
modified = True
Expand Down Expand Up @@ -606,13 +598,13 @@ def _extract_function_from_code(
# file_path is needed for JS/TS to determine correct analyzer (TypeScript vs JavaScript)
functions = lang_support.discover_functions_from_source(source_code, file_path)
for func in functions:
if func.name == function_name:
if func.function_name == function_name:
# Extract the function's source using line numbers
# Use doc_start_line if available to include JSDoc/docstring
lines = source_code.splitlines(keepends=True)
effective_start = func.doc_start_line or func.start_line
if effective_start and func.end_line and effective_start <= len(lines):
func_lines = lines[effective_start - 1 : func.end_line]
effective_start = func.doc_start_line or func.starting_line
if effective_start and func.ending_line and effective_start <= len(lines):
func_lines = lines[effective_start - 1 : func.ending_line]
return "".join(func_lines)
except Exception as e:
logger.debug(f"Error extracting function {function_name}: {e}")
Expand Down
5 changes: 2 additions & 3 deletions codeflash/code_utils/env_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from codeflash.code_utils.code_utils import exit_with_message
from codeflash.code_utils.formatter import format_code
from codeflash.code_utils.shell_utils import read_api_key_from_shell_config, save_api_key_to_rc
from codeflash.languages.base import Language
from codeflash.languages.registry import get_language_support_by_common_formatters
from codeflash.lsp.helpers import is_LSP_enabled

Expand Down Expand Up @@ -44,9 +43,9 @@ def check_formatter_installed(
logger.debug(f"Could not determine language for formatter: {formatter_cmds}")
return True

if lang_support.language == Language.PYTHON:
if str(lang_support.language) == "python":
tmp_code = """print("hello world")"""
elif lang_support.language in (Language.JAVASCRIPT, Language.TYPESCRIPT):
elif str(lang_support.language) in ("javascript", "typescript"):
tmp_code = "console.log('hello world');"
else:
return True
Expand Down
7 changes: 6 additions & 1 deletion codeflash/code_utils/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import isort

from codeflash.cli_cmds.console import console, logger
from codeflash.languages.registry import get_language_support
from codeflash.lsp.helpers import is_LSP_enabled


Expand Down Expand Up @@ -43,6 +42,8 @@ def split_lines(text: str) -> list[str]:
def apply_formatter_cmds(
cmds: list[str], path: Path, test_dir_str: Optional[str], print_status: bool, exit_on_failure: bool = True
) -> tuple[Path, str, bool]:
from codeflash.languages.registry import get_language_support

if not path.exists():
msg = f"File {path} does not exist. Cannot apply formatter commands."
raise FileNotFoundError(msg)
Expand Down Expand Up @@ -90,6 +91,8 @@ def is_diff_line(line: str) -> bool:


def format_generated_code(generated_test_source: str, formatter_cmds: list[str], language: str = "python") -> str:
from codeflash.languages.registry import get_language_support

formatter_name = formatter_cmds[0].lower() if formatter_cmds else "disabled"
if formatter_name == "disabled": # nothing to do if no formatter provided
return re.sub(r"\n{2,}", "\n\n", generated_test_source)
Expand All @@ -114,6 +117,8 @@ def format_code(
print_status: bool = True,
exit_on_failure: bool = True,
) -> str:
from codeflash.languages.registry import get_language_support

if is_LSP_enabled():
exit_on_failure = False

Expand Down
2 changes: 1 addition & 1 deletion codeflash/code_utils/static_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from pydantic import BaseModel, ConfigDict, field_validator

if TYPE_CHECKING:
from codeflash.models.models import FunctionParent
from codeflash.models.function_types import FunctionParent


ObjectDefT = TypeVar("ObjectDefT", ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)
Expand Down
19 changes: 2 additions & 17 deletions codeflash/context/code_context_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@
from codeflash.discovery.functions_to_optimize import FunctionToOptimize # noqa: TC001

# Language support imports for multi-language code context extraction
from codeflash.languages import is_python
from codeflash.languages.base import Language
from codeflash.languages import Language, is_python
from codeflash.models.models import (
CodeContextType,
CodeOptimizationContext,
Expand Down Expand Up @@ -234,27 +233,13 @@ def get_code_optimization_context_for_language(

"""
from codeflash.languages import get_language_support
from codeflash.languages.base import FunctionInfo, ParentInfo

# Get language support for this function
language = Language(function_to_optimize.language)
lang_support = get_language_support(language)

# Convert FunctionToOptimize to FunctionInfo for language support
parents = tuple(ParentInfo(name=p.name, type=p.type) for p in function_to_optimize.parents)
func_info = FunctionInfo(
name=function_to_optimize.function_name,
file_path=function_to_optimize.file_path,
start_line=function_to_optimize.starting_line or 1,
end_line=function_to_optimize.ending_line or 1,
parents=parents,
is_async=function_to_optimize.is_async,
is_method=len(function_to_optimize.parents) > 0,
language=language,
)

# Extract code context using language support
code_context = lang_support.extract_code_context(func_info, project_root_path, project_root_path)
code_context = lang_support.extract_code_context(function_to_optimize, project_root_path, project_root_path)

# Build imports string if available
imports_code = "\n".join(code_context.imports) if code_context.imports else ""
Expand Down
Loading
Loading