Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 131 additions & 48 deletions codeflash/code_utils/code_replacer.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,19 +628,22 @@ def _add_global_declarations_for_language(
Finds module-level declarations (const, let, var, class, type, interface, enum)
in the optimized code that don't exist in the original source and adds them.

New declarations are inserted after any existing declarations they depend on.
For example, if optimized code has `const _has = FOO.bar.bind(FOO)`, and `FOO`
is already declared in the original source, `_has` will be inserted after `FOO`.

Args:
optimized_code: The optimized code that may contain new declarations.
original_source: The original source code.
module_abspath: Path to the module file (for parser selection).
language: The language of the code.

Returns:
Original source with new declarations added after imports.
Original source with new declarations added in dependency order.

"""
from codeflash.languages.base import Language

# Only process JavaScript/TypeScript
if language not in (Language.JAVASCRIPT, Language.TYPESCRIPT):
return original_source

Expand All @@ -649,84 +652,164 @@ def _add_global_declarations_for_language(

analyzer = get_analyzer_for_file(module_abspath)

# Find declarations in both original and optimized code
original_declarations = analyzer.find_module_level_declarations(original_source)
optimized_declarations = analyzer.find_module_level_declarations(optimized_code)

if not optimized_declarations:
return original_source

# Get names of existing declarations
existing_names = {decl.name for decl in original_declarations}

# Also exclude names that are already imported (to avoid duplicating imported types)
original_imports = analyzer.find_imports(original_source)
for imp in original_imports:
# Add default import name
if imp.default_import:
existing_names.add(imp.default_import)
# Add named imports (use alias if present, otherwise use original name)
for name, alias in imp.named_imports:
existing_names.add(alias if alias else name)
# Add namespace import
if imp.namespace_import:
existing_names.add(imp.namespace_import)

# Find new declarations (names that don't exist in original)
new_declarations = []
seen_sources = set() # Track to avoid duplicates from destructuring
for decl in optimized_declarations:
if decl.name not in existing_names and decl.source_code not in seen_sources:
new_declarations.append(decl)
seen_sources.add(decl.source_code)
existing_names = _get_existing_names(original_declarations, analyzer, original_source)
new_declarations = _filter_new_declarations(optimized_declarations, existing_names)

if not new_declarations:
return original_source

# Sort by line number to maintain order
new_declarations.sort(key=lambda d: d.start_line)

# Find insertion point (after imports)
lines = original_source.splitlines(keepends=True)
insertion_line = _find_insertion_line_after_imports_js(lines, analyzer, original_source)

# Build new declarations string
new_decl_code = "\n".join(decl.source_code for decl in new_declarations)
new_decl_code = new_decl_code + "\n\n"
# Build a map of existing declaration names to their end lines (1-indexed)
existing_decl_end_lines = {decl.name: decl.end_line for decl in original_declarations}

# Insert declarations
before = lines[:insertion_line]
after = lines[insertion_line:]
result_lines = [*before, new_decl_code, *after]
# Insert each new declaration after its dependencies
result = original_source
for decl in new_declarations:
result = _insert_declaration_after_dependencies(
result, decl, existing_decl_end_lines, analyzer, module_abspath
)
# Update the map with the newly inserted declaration for subsequent insertions
# Re-parse to get accurate line numbers after insertion
updated_declarations = analyzer.find_module_level_declarations(result)
existing_decl_end_lines = {d.name: d.end_line for d in updated_declarations}

return "".join(result_lines)
return result

except Exception as e:
logger.debug(f"Error adding global declarations: {e}")
return original_source


def _find_insertion_line_after_imports_js(lines: list[str], analyzer: TreeSitterAnalyzer, source: str) -> int:
"""Find the line index where new declarations should be inserted (after imports).
def _get_existing_names(original_declarations: list, analyzer: TreeSitterAnalyzer, original_source: str) -> set[str]:
"""Get all names that already exist in the original source (declarations + imports)."""
existing_names = {decl.name for decl in original_declarations}

original_imports = analyzer.find_imports(original_source)
for imp in original_imports:
if imp.default_import:
existing_names.add(imp.default_import)
for name, alias in imp.named_imports:
existing_names.add(alias if alias else name)
if imp.namespace_import:
existing_names.add(imp.namespace_import)

return existing_names


def _filter_new_declarations(optimized_declarations: list, existing_names: set[str]) -> list:
"""Filter declarations to only those that don't exist in the original source."""
new_declarations = []
seen_sources: set[str] = set()

# Sort by line number to maintain order from optimized code
sorted_declarations = sorted(optimized_declarations, key=lambda d: d.start_line)

for decl in sorted_declarations:
if decl.name not in existing_names and decl.source_code not in seen_sources:
new_declarations.append(decl)
seen_sources.add(decl.source_code)

return new_declarations


def _insert_declaration_after_dependencies(
source: str,
declaration,
existing_decl_end_lines: dict[str, int],
analyzer: TreeSitterAnalyzer,
module_abspath: Path,
) -> str:
"""Insert a declaration after the last existing declaration it depends on.

Args:
source: Current source code.
declaration: The declaration to insert.
existing_decl_end_lines: Map of existing declaration names to their end lines.
analyzer: TreeSitter analyzer.
module_abspath: Path to the module file.

Returns:
Source code with the declaration inserted at the correct position.

"""
# Find identifiers referenced in this declaration
referenced_names = analyzer.find_referenced_identifiers(declaration.source_code)

# Find the latest end line among all referenced declarations
insertion_line = _find_insertion_line_for_declaration(source, referenced_names, existing_decl_end_lines, analyzer)

lines = source.splitlines(keepends=True)

# Ensure proper spacing
decl_code = declaration.source_code
if not decl_code.endswith("\n"):
decl_code += "\n"

# Add blank line before if inserting after content
if insertion_line > 0 and lines[insertion_line - 1].strip():
decl_code = "\n" + decl_code

before = lines[:insertion_line]
after = lines[insertion_line:]

return "".join([*before, decl_code, *after])


def _find_insertion_line_for_declaration(
source: str, referenced_names: set[str], existing_decl_end_lines: dict[str, int], analyzer: TreeSitterAnalyzer
) -> int:
"""Find the line where a declaration should be inserted based on its dependencies.

Args:
source: Source code.
referenced_names: Names referenced by the declaration.
existing_decl_end_lines: Map of declaration names to their end lines (1-indexed).
analyzer: TreeSitter analyzer.

Returns:
Line index (0-based) where the declaration should be inserted.

"""
# Find the maximum end line among referenced declarations
max_dependency_line = 0
for name in referenced_names:
if name in existing_decl_end_lines:
max_dependency_line = max(max_dependency_line, existing_decl_end_lines[name])

if max_dependency_line > 0:
# Insert after the last dependency (end_line is 1-indexed, we need 0-indexed)
return max_dependency_line

# No dependencies found - insert after imports
lines = source.splitlines(keepends=True)
return _find_line_after_imports(lines, analyzer, source)


def _find_line_after_imports(lines: list[str], analyzer: TreeSitterAnalyzer, source: str) -> int:
"""Find the line index after all imports.

Args:
lines: Source lines.
analyzer: TreeSitter analyzer for the file.
analyzer: TreeSitter analyzer.
source: Full source code.

Returns:
Line index (0-based) for insertion.
Line index (0-based) for insertion after imports.

"""
try:
imports = analyzer.find_imports(source)
if imports:
# Find the last import's end line
return max(imp.end_line for imp in imports)
except Exception as exc:
logger.debug(f"Exception occurred in _find_insertion_line_after_imports_js: {exc}")
logger.debug(f"Exception in _find_line_after_imports: {exc}")

# Default: insert at beginning (after any shebang/directive comments)
# Default: insert at beginning (after shebang/directive comments)
for i, line in enumerate(lines):
stripped = line.strip()
if stripped and not stripped.startswith("//") and not stripped.startswith("#!"):
Expand Down
2 changes: 1 addition & 1 deletion codeflash/result/create_pr.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,8 +281,8 @@ def check_create_pr(
function_trace_id: str,
coverage_message: str,
replay_tests: str,
concolic_tests: str,
root_dir: Path,
concolic_tests: str = "",
git_remote: Optional[str] = None,
optimization_review: str = "",
original_line_profiler: str | None = None,
Expand Down
78 changes: 78 additions & 0 deletions tests/test_languages/test_js_code_replacer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1895,6 +1895,84 @@ class DataProcessor<T> {



class TestNewVariableFromOptimizedCode:
"""Tests for handling new variables introduced in optimized code."""

def test_new_bound_method_variable_added_after_referenced_constant(self, ts_support, temp_project):
"""Test that a new variable binding a method is added after the constant it references.

When optimized code introduces a new module-level variable (like `_has`) that
references an existing constant (like `CODEFLASH_EMPLOYEE_GITHUB_IDS`), the
replacement should:
1. Add the new variable after the constant it references
2. Replace the function with the optimized version
"""
from codeflash.models.models import CodeStringsMarkdown, CodeString

original_source = '''\
const CODEFLASH_EMPLOYEE_GITHUB_IDS = new Set([
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mohammedahmed18 this looks sensitive

"1234",
]);

export function isCodeflashEmployee(userId: string): boolean {
return CODEFLASH_EMPLOYEE_GITHUB_IDS.has(userId);
}
'''
file_path = temp_project / "auth.ts"
file_path.write_text(original_source, encoding="utf-8")

# Optimized code introduces a bound method variable for performance
optimized_code = '''const _has: (id: string) => boolean = CODEFLASH_EMPLOYEE_GITHUB_IDS.has.bind(
CODEFLASH_EMPLOYEE_GITHUB_IDS
);

export function isCodeflashEmployee(userId: string): boolean {
return _has(userId);
}
'''

code_markdown = CodeStringsMarkdown(
code_strings=[
CodeString(
code=optimized_code,
file_path=Path("auth.ts"),
language="typescript"
)
],
language="typescript"
)

replaced = replace_function_definitions_for_language(
["isCodeflashEmployee"],
code_markdown,
file_path,
temp_project,
)

assert replaced
result = file_path.read_text()

# Expected result for strict equality check
expected_result = '''\
const CODEFLASH_EMPLOYEE_GITHUB_IDS = new Set([
"1234",
]);

const _has: (id: string) => boolean = CODEFLASH_EMPLOYEE_GITHUB_IDS.has.bind(
CODEFLASH_EMPLOYEE_GITHUB_IDS
);

export function isCodeflashEmployee(userId: string): boolean {
return _has(userId);
}
'''
assert result == expected_result, (
f"Result does not match expected output.\n"
f"Expected:\n{expected_result}\n\n"
f"Got:\n{result}"
)


class TestImportedTypeNotDuplicated:
"""Tests to ensure imported types are not duplicated during code replacement.

Expand Down
Loading