From 067ef3920c004d4943acddcd1d9d3dbec4a67c6e Mon Sep 17 00:00:00 2001 From: ali Date: Mon, 2 Feb 2026 16:00:29 +0200 Subject: [PATCH 1/3] add new global definitions after their deps --- codeflash/code_utils/code_replacer.py | 179 +++++++++++++----- tests/test_languages/test_js_code_replacer.py | 90 +++++++++ 2 files changed, 221 insertions(+), 48 deletions(-) diff --git a/codeflash/code_utils/code_replacer.py b/codeflash/code_utils/code_replacer.py index 6a57b61e1..87cb3c674 100644 --- a/codeflash/code_utils/code_replacer.py +++ b/codeflash/code_utils/code_replacer.py @@ -628,6 +628,10 @@ def _add_global_declarations_for_language( Finds module-level declarations (const, let, var, class, type, interface, enum) in the optimized code that don't exist in the original source and adds them. + New declarations are inserted after any existing declarations they depend on. + For example, if optimized code has `const _has = FOO.bar.bind(FOO)`, and `FOO` + is already declared in the original source, `_has` will be inserted after `FOO`. + Args: optimized_code: The optimized code that may contain new declarations. original_source: The original source code. @@ -635,12 +639,11 @@ def _add_global_declarations_for_language( language: The language of the code. Returns: - Original source with new declarations added after imports. + Original source with new declarations added in dependency order. """ from codeflash.languages.base import Language - # Only process JavaScript/TypeScript if language not in (Language.JAVASCRIPT, Language.TYPESCRIPT): return original_source @@ -649,84 +652,164 @@ def _add_global_declarations_for_language( analyzer = get_analyzer_for_file(module_abspath) - # Find declarations in both original and optimized code original_declarations = analyzer.find_module_level_declarations(original_source) optimized_declarations = analyzer.find_module_level_declarations(optimized_code) if not optimized_declarations: return original_source - # Get names of existing declarations - existing_names = {decl.name for decl in original_declarations} - - # Also exclude names that are already imported (to avoid duplicating imported types) - original_imports = analyzer.find_imports(original_source) - for imp in original_imports: - # Add default import name - if imp.default_import: - existing_names.add(imp.default_import) - # Add named imports (use alias if present, otherwise use original name) - for name, alias in imp.named_imports: - existing_names.add(alias if alias else name) - # Add namespace import - if imp.namespace_import: - existing_names.add(imp.namespace_import) - - # Find new declarations (names that don't exist in original) - new_declarations = [] - seen_sources = set() # Track to avoid duplicates from destructuring - for decl in optimized_declarations: - if decl.name not in existing_names and decl.source_code not in seen_sources: - new_declarations.append(decl) - seen_sources.add(decl.source_code) + existing_names = _get_existing_names(original_declarations, analyzer, original_source) + new_declarations = _filter_new_declarations(optimized_declarations, existing_names) if not new_declarations: return original_source - # Sort by line number to maintain order - new_declarations.sort(key=lambda d: d.start_line) - - # Find insertion point (after imports) - lines = original_source.splitlines(keepends=True) - insertion_line = _find_insertion_line_after_imports_js(lines, analyzer, original_source) - - # Build new declarations string - new_decl_code = "\n".join(decl.source_code for decl in new_declarations) - new_decl_code = new_decl_code + "\n\n" + # Build a map of existing declaration names to their end lines (1-indexed) + existing_decl_end_lines = {decl.name: decl.end_line for decl in original_declarations} - # Insert declarations - before = lines[:insertion_line] - after = lines[insertion_line:] - result_lines = [*before, new_decl_code, *after] + # Insert each new declaration after its dependencies + result = original_source + for decl in new_declarations: + result = _insert_declaration_after_dependencies( + result, decl, existing_decl_end_lines, analyzer, module_abspath + ) + # Update the map with the newly inserted declaration for subsequent insertions + # Re-parse to get accurate line numbers after insertion + updated_declarations = analyzer.find_module_level_declarations(result) + existing_decl_end_lines = {d.name: d.end_line for d in updated_declarations} - return "".join(result_lines) + return result except Exception as e: logger.debug(f"Error adding global declarations: {e}") return original_source -def _find_insertion_line_after_imports_js(lines: list[str], analyzer: TreeSitterAnalyzer, source: str) -> int: - """Find the line index where new declarations should be inserted (after imports). +def _get_existing_names(original_declarations: list, analyzer: TreeSitterAnalyzer, original_source: str) -> set[str]: + """Get all names that already exist in the original source (declarations + imports).""" + existing_names = {decl.name for decl in original_declarations} + + original_imports = analyzer.find_imports(original_source) + for imp in original_imports: + if imp.default_import: + existing_names.add(imp.default_import) + for name, alias in imp.named_imports: + existing_names.add(alias if alias else name) + if imp.namespace_import: + existing_names.add(imp.namespace_import) + + return existing_names + + +def _filter_new_declarations(optimized_declarations: list, existing_names: set[str]) -> list: + """Filter declarations to only those that don't exist in the original source.""" + new_declarations = [] + seen_sources: set[str] = set() + + # Sort by line number to maintain order from optimized code + sorted_declarations = sorted(optimized_declarations, key=lambda d: d.start_line) + + for decl in sorted_declarations: + if decl.name not in existing_names and decl.source_code not in seen_sources: + new_declarations.append(decl) + seen_sources.add(decl.source_code) + + return new_declarations + + +def _insert_declaration_after_dependencies( + source: str, + declaration, + existing_decl_end_lines: dict[str, int], + analyzer: TreeSitterAnalyzer, + module_abspath: Path, +) -> str: + """Insert a declaration after the last existing declaration it depends on. + + Args: + source: Current source code. + declaration: The declaration to insert. + existing_decl_end_lines: Map of existing declaration names to their end lines. + analyzer: TreeSitter analyzer. + module_abspath: Path to the module file. + + Returns: + Source code with the declaration inserted at the correct position. + + """ + # Find identifiers referenced in this declaration + referenced_names = analyzer.find_referenced_identifiers(declaration.source_code) + + # Find the latest end line among all referenced declarations + insertion_line = _find_insertion_line_for_declaration(source, referenced_names, existing_decl_end_lines, analyzer) + + lines = source.splitlines(keepends=True) + + # Ensure proper spacing + decl_code = declaration.source_code + if not decl_code.endswith("\n"): + decl_code += "\n" + + # Add blank line before if inserting after content + if insertion_line > 0 and lines[insertion_line - 1].strip(): + decl_code = "\n" + decl_code + + before = lines[:insertion_line] + after = lines[insertion_line:] + + return "".join([*before, decl_code, *after]) + + +def _find_insertion_line_for_declaration( + source: str, referenced_names: set[str], existing_decl_end_lines: dict[str, int], analyzer: TreeSitterAnalyzer +) -> int: + """Find the line where a declaration should be inserted based on its dependencies. + + Args: + source: Source code. + referenced_names: Names referenced by the declaration. + existing_decl_end_lines: Map of declaration names to their end lines (1-indexed). + analyzer: TreeSitter analyzer. + + Returns: + Line index (0-based) where the declaration should be inserted. + + """ + # Find the maximum end line among referenced declarations + max_dependency_line = 0 + for name in referenced_names: + if name in existing_decl_end_lines: + max_dependency_line = max(max_dependency_line, existing_decl_end_lines[name]) + + if max_dependency_line > 0: + # Insert after the last dependency (end_line is 1-indexed, we need 0-indexed) + return max_dependency_line + + # No dependencies found - insert after imports + lines = source.splitlines(keepends=True) + return _find_line_after_imports(lines, analyzer, source) + + +def _find_line_after_imports(lines: list[str], analyzer: TreeSitterAnalyzer, source: str) -> int: + """Find the line index after all imports. Args: lines: Source lines. - analyzer: TreeSitter analyzer for the file. + analyzer: TreeSitter analyzer. source: Full source code. Returns: - Line index (0-based) for insertion. + Line index (0-based) for insertion after imports. """ try: imports = analyzer.find_imports(source) if imports: - # Find the last import's end line return max(imp.end_line for imp in imports) except Exception as exc: - logger.debug(f"Exception occurred in _find_insertion_line_after_imports_js: {exc}") + logger.debug(f"Exception in _find_line_after_imports: {exc}") - # Default: insert at beginning (after any shebang/directive comments) + # Default: insert at beginning (after shebang/directive comments) for i, line in enumerate(lines): stripped = line.strip() if stripped and not stripped.startswith("//") and not stripped.startswith("#!"): diff --git a/tests/test_languages/test_js_code_replacer.py b/tests/test_languages/test_js_code_replacer.py index 3d703aa34..992665703 100644 --- a/tests/test_languages/test_js_code_replacer.py +++ b/tests/test_languages/test_js_code_replacer.py @@ -1895,6 +1895,96 @@ class DataProcessor { +class TestNewVariableFromOptimizedCode: + """Tests for handling new variables introduced in optimized code.""" + + def test_new_bound_method_variable_added_after_referenced_constant(self, ts_support, temp_project): + """Test that a new variable binding a method is added after the constant it references. + + When optimized code introduces a new module-level variable (like `_has`) that + references an existing constant (like `CODEFLASH_EMPLOYEE_GITHUB_IDS`), the + replacement should: + 1. Add the new variable after the constant it references + 2. Replace the function with the optimized version + """ + from codeflash.models.models import CodeStringsMarkdown, CodeString + + original_source = '''\ +const CODEFLASH_EMPLOYEE_GITHUB_IDS = new Set([ + "github|1271289", + "github|10488227", + "github|64513301", + "github|106575910", + "github|206515457", + "github|4725571", + "github|235426847", +]); + +export function isCodeflashEmployee(userId: string): boolean { + return CODEFLASH_EMPLOYEE_GITHUB_IDS.has(userId); +} +''' + file_path = temp_project / "auth.ts" + file_path.write_text(original_source, encoding="utf-8") + + # Optimized code introduces a bound method variable for performance + optimized_code = '''const _has: (id: string) => boolean = CODEFLASH_EMPLOYEE_GITHUB_IDS.has.bind( + CODEFLASH_EMPLOYEE_GITHUB_IDS +); + +export function isCodeflashEmployee(userId: string): boolean { + return _has(userId); +} +''' + + code_markdown = CodeStringsMarkdown( + code_strings=[ + CodeString( + code=optimized_code, + file_path=Path("auth.ts"), + language="typescript" + ) + ], + language="typescript" + ) + + replaced = replace_function_definitions_for_language( + ["isCodeflashEmployee"], + code_markdown, + file_path, + temp_project, + ) + + assert replaced + result = file_path.read_text() + + # Expected result for strict equality check + expected_result = '''\ +const CODEFLASH_EMPLOYEE_GITHUB_IDS = new Set([ + "github|1271289", + "github|10488227", + "github|64513301", + "github|106575910", + "github|206515457", + "github|4725571", + "github|235426847", +]); + +const _has: (id: string) => boolean = CODEFLASH_EMPLOYEE_GITHUB_IDS.has.bind( + CODEFLASH_EMPLOYEE_GITHUB_IDS +); + +export function isCodeflashEmployee(userId: string): boolean { + return _has(userId); +} +''' + assert result == expected_result, ( + f"Result does not match expected output.\n" + f"Expected:\n{expected_result}\n\n" + f"Got:\n{result}" + ) + + class TestImportedTypeNotDuplicated: """Tests to ensure imported types are not duplicated during code replacement. From c619d39c2ef197338d5f9dd8d6e52bd8fd44e64c Mon Sep 17 00:00:00 2001 From: ali Date: Mon, 2 Feb 2026 21:38:53 +0200 Subject: [PATCH 2/3] fix a typo --- codeflash/result/create_pr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/codeflash/result/create_pr.py b/codeflash/result/create_pr.py index 8e16e167e..0be4e1cf8 100644 --- a/codeflash/result/create_pr.py +++ b/codeflash/result/create_pr.py @@ -281,8 +281,8 @@ def check_create_pr( function_trace_id: str, coverage_message: str, replay_tests: str, - concolic_tests: str, root_dir: Path, + concolic_tests: str = "", git_remote: Optional[str] = None, optimization_review: str = "", original_line_profiler: str | None = None, From 20a32bb714ea2bc639a4876eb58d47b41dcbbee2 Mon Sep 17 00:00:00 2001 From: ali Date: Tue, 3 Feb 2026 02:10:35 +0200 Subject: [PATCH 3/3] remove github ids from test --- tests/test_languages/test_js_code_replacer.py | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/tests/test_languages/test_js_code_replacer.py b/tests/test_languages/test_js_code_replacer.py index 992665703..7dd206539 100644 --- a/tests/test_languages/test_js_code_replacer.py +++ b/tests/test_languages/test_js_code_replacer.py @@ -1911,13 +1911,7 @@ def test_new_bound_method_variable_added_after_referenced_constant(self, ts_supp original_source = '''\ const CODEFLASH_EMPLOYEE_GITHUB_IDS = new Set([ - "github|1271289", - "github|10488227", - "github|64513301", - "github|106575910", - "github|206515457", - "github|4725571", - "github|235426847", + "1234", ]); export function isCodeflashEmployee(userId: string): boolean { @@ -1961,13 +1955,7 @@ def test_new_bound_method_variable_added_after_referenced_constant(self, ts_supp # Expected result for strict equality check expected_result = '''\ const CODEFLASH_EMPLOYEE_GITHUB_IDS = new Set([ - "github|1271289", - "github|10488227", - "github|64513301", - "github|106575910", - "github|206515457", - "github|4725571", - "github|235426847", + "1234", ]); const _has: (id: string) => boolean = CODEFLASH_EMPLOYEE_GITHUB_IDS.has.bind(