From ddfaed74f9e625e4294b4ab3ca5c0277af6514e2 Mon Sep 17 00:00:00 2001 From: misrasaurabh1 Date: Mon, 2 Feb 2026 01:53:02 -0800 Subject: [PATCH 1/5] refactor: use FunctionToOptimize dataclass instead of raw strings in JS optimization Replace raw function_name/qualified_name/class_name string parameters with the FunctionToOptimize dataclass throughout JavaScript optimization code. This provides better type safety and eliminates redundant parameters. Changes: - Add class_name property to FunctionToOptimize (similar to FunctionInfo) - Add from_function_info classmethod for FunctionInfo conversion - Update instrument.py transformers and functions to accept FunctionToOptimize - Update find_references.py to accept FunctionToOptimize - Update support.py and verifier.py callers - Update all test files Co-Authored-By: Claude Opus 4.5 --- codeflash/discovery/functions_to_optimize.py | 29 +++ .../languages/javascript/find_references.py | 93 ++++----- codeflash/languages/javascript/instrument.py | 74 +++----- codeflash/languages/javascript/support.py | 6 +- codeflash/verification/verifier.py | 14 +- tests/test_javascript_assertion_removal.py | 177 ++++++++++-------- tests/test_languages/test_find_references.py | 61 +++--- .../test_javascript_instrumentation.py | 46 +++-- 8 files changed, 266 insertions(+), 234 deletions(-) diff --git a/codeflash/discovery/functions_to_optimize.py b/codeflash/discovery/functions_to_optimize.py index d9fd8e7f9..830380bde 100644 --- a/codeflash/discovery/functions_to_optimize.py +++ b/codeflash/discovery/functions_to_optimize.py @@ -39,6 +39,7 @@ from libcst import CSTNode from libcst.metadata import CodeRange + from codeflash.languages.base import FunctionInfo from codeflash.models.models import CodeOptimizationContext from codeflash.verification.verification_utils import TestConfig import contextlib @@ -165,6 +166,14 @@ class FunctionToOptimize: def top_level_parent_name(self) -> str: return self.function_name if not self.parents else self.parents[0].name + @property + def class_name(self) -> str | None: + """Get the immediate parent class name, if any.""" + for parent in reversed(self.parents): + if parent.type == "ClassDef": + return parent.name + return None + def __str__(self) -> str: return ( f"{self.file_path}:{'.'.join([p.name for p in self.parents])}" @@ -182,6 +191,26 @@ def qualified_name(self) -> str: def qualified_name_with_modules_from_root(self, project_root_path: Path) -> str: return f"{module_name_from_file_path(self.file_path, project_root_path)}.{self.qualified_name}" + @classmethod + def from_function_info(cls, func_info: FunctionInfo) -> FunctionToOptimize: + """Create a FunctionToOptimize from a FunctionInfo instance. + + This enables interoperability between the language-agnostic FunctionInfo + and the FunctionToOptimize dataclass used throughout the codebase. + """ + parents = [FunctionParent(name=p.name, type=p.type) for p in func_info.parents] + return cls( + function_name=func_info.name, + file_path=func_info.file_path, + parents=parents, + starting_line=func_info.start_line, + ending_line=func_info.end_line, + starting_col=func_info.start_col, + ending_col=func_info.end_col, + is_async=func_info.is_async, + language=func_info.language.value, + ) + # ============================================================================= # Multi-language support helpers diff --git a/codeflash/languages/javascript/find_references.py b/codeflash/languages/javascript/find_references.py index aea9b071b..18e096a74 100644 --- a/codeflash/languages/javascript/find_references.py +++ b/codeflash/languages/javascript/find_references.py @@ -21,7 +21,9 @@ if TYPE_CHECKING: from tree_sitter import Node - from codeflash.languages.treesitter_utils import ExportInfo, ImportInfo, TreeSitterAnalyzer + from codeflash.languages.treesitter_utils import ImportInfo, TreeSitterAnalyzer + +from codeflash.discovery.functions_to_optimize import FunctionToOptimize logger = logging.getLogger(__name__) @@ -68,12 +70,16 @@ class ReferenceFinder: Example usage: ```python from codeflash.languages.javascript.find_references import ReferenceFinder + from codeflash.discovery.functions_to_optimize import FunctionToOptimize - finder = ReferenceFinder(project_root=Path("/my/project")) - references = finder.find_references( + func = FunctionToOptimize( function_name="myHelper", - source_file=Path("/my/project/src/utils.ts") + file_path=Path("/my/project/src/utils.ts"), + parents=[], + language="javascript" ) + finder = ReferenceFinder(project_root=Path("/my/project")) + references = finder.find_references(func) for ref in references: print(f"{ref.file_path}:{ref.line} - {ref.context}") ``` @@ -96,21 +102,14 @@ def __init__(self, project_root: Path, exclude_patterns: list[str] | None = None self._file_cache: dict[Path, str] = {} def find_references( - self, - function_name: str, - source_file: Path, - include_definition: bool = False, - max_files: int = 1000, - class_name: str | None = None, + self, function_to_optimize: FunctionToOptimize, include_definition: bool = False, max_files: int = 1000 ) -> list[Reference]: """Find all references to a function across the project. Args: - function_name: Name of the function to find references for. - source_file: Path to the file where the function is defined. + function_to_optimize: The function to find references for. include_definition: Whether to include the function definition itself. max_files: Maximum number of files to search (prevents runaway searches). - class_name: For class methods, the name of the containing class. Returns: List of Reference objects describing each call site. @@ -118,6 +117,9 @@ def find_references( """ from codeflash.languages.treesitter_utils import get_analyzer_for_file + function_name = function_to_optimize.function_name + source_file = function_to_optimize.file_path + references: list[Reference] = [] context = ReferenceSearchContext(max_files=max_files) @@ -128,7 +130,7 @@ def find_references( return references analyzer = get_analyzer_for_file(source_file) - exported = self._analyze_exports(function_name, source_file, source_code, analyzer, class_name) + exported = self._analyze_exports(function_to_optimize, source_file, source_code, analyzer) if not exported: logger.debug("Function %s is not exported from %s", function_name, source_file) @@ -179,9 +181,7 @@ def find_references( # This handles the case where a file re-exports from our source file if file_path not in checked_for_reexports: checked_for_reexports.add(file_path) - reexport_refs = self._find_reexports_direct( - file_path, file_code, source_file, exported, file_analyzer - ) + reexport_refs = self._find_reexports_direct(file_path, file_code, source_file, exported, file_analyzer) references.extend(reexport_refs) # Track re-export files for later searching @@ -192,10 +192,7 @@ def find_references( for reexport_file, reexport_name in reexport_files: # Create a new ExportedFunction for the re-exported function reexported = ExportedFunction( - function_name=reexport_name, - export_name=reexport_name, - is_default=False, - file_path=reexport_file, + function_name=reexport_name, export_name=reexport_name, is_default=False, file_path=reexport_file ) # Search for imports to the re-export file @@ -252,28 +249,24 @@ def find_references( return unique_refs def _analyze_exports( - self, - function_name: str, - file_path: Path, - source_code: str, - analyzer: TreeSitterAnalyzer, - class_name: str | None = None, + self, function_to_optimize: FunctionToOptimize, file_path: Path, source_code: str, analyzer: TreeSitterAnalyzer ) -> ExportedFunction | None: """Analyze how a function is exported from its file. For class methods, also checks if the containing class is exported. Args: - function_name: Name of the function to check. + function_to_optimize: The function to check. file_path: Path to the source file. source_code: Source code content. analyzer: TreeSitterAnalyzer instance. - class_name: For class methods, the name of the containing class. Returns: ExportedFunction if the function is exported, None otherwise. """ + function_name = function_to_optimize.function_name + class_name = function_to_optimize.class_name is_exported, export_name = analyzer.is_function_exported(source_code, function_name, class_name) if not is_exported: @@ -287,11 +280,7 @@ def _analyze_exports( ) def _find_matching_import( - self, - imports: list[ImportInfo], - source_file: Path, - importing_file: Path, - exported: ExportedFunction, + self, imports: list[ImportInfo], source_file: Path, importing_file: Path, exported: ExportedFunction ) -> tuple[str, ImportInfo] | None: """Find if any import in a file imports the target function. @@ -379,9 +368,7 @@ def _find_references_in_file( # Handle namespace imports (e.g., "utils.helper") if "." in search_name: namespace, member = search_name.split(".", 1) - self._find_member_calls( - tree.root_node, source_bytes, lines, file_path, namespace, member, references, None - ) + self._find_member_calls(tree.root_node, source_bytes, lines, file_path, namespace, member, references, None) else: # Find direct calls and other reference types self._find_identifier_references( @@ -414,8 +401,6 @@ def _find_identifier_references( current_function: Name of the containing function (for context). """ - from codeflash.languages.treesitter_utils import TreeSitterAnalyzer - # Track current function context new_current_function = current_function if node.type in ("function_declaration", "method_definition"): @@ -435,9 +420,7 @@ def _find_identifier_references( if func_node and func_node.type == "identifier": name = source_bytes[func_node.start_byte : func_node.end_byte].decode("utf8") if name == search_name: - ref = self._create_reference( - file_path, func_node, lines, "call", search_name, current_function - ) + ref = self._create_reference(file_path, func_node, lines, "call", search_name, current_function) references.append(ref) # Check for identifiers used as callbacks or passed as arguments @@ -448,9 +431,7 @@ def _find_identifier_references( # Determine reference type based on context ref_type = self._determine_reference_type(node, parent, source_bytes) if ref_type: - ref = self._create_reference( - file_path, node, lines, ref_type, search_name, current_function - ) + ref = self._create_reference(file_path, node, lines, ref_type, search_name, current_function) references.append(ref) # Recurse into children @@ -831,22 +812,16 @@ def _read_file(self, file_path: Path) -> str | None: def find_references( - function_name: str, - source_file: Path, - project_root: Path | None = None, - max_files: int = 1000, - class_name: str | None = None, + function_to_optimize: FunctionToOptimize, project_root: Path | None = None, max_files: int = 1000 ) -> list[Reference]: """Convenience function to find all references to a function. This is a simple wrapper around ReferenceFinder for common use cases. Args: - function_name: Name of the function to find references for. - source_file: Path to the file where the function is defined. + function_to_optimize: The function to find references for. project_root: Root directory of the project. If None, uses source_file's parent. max_files: Maximum number of files to search. - class_name: For class methods, the name of the containing class. Returns: List of Reference objects describing each call site. @@ -855,19 +830,19 @@ def find_references( ```python from pathlib import Path from codeflash.languages.javascript.find_references import find_references + from codeflash.discovery.functions_to_optimize import FunctionToOptimize - refs = find_references( - function_name="myHelper", - source_file=Path("/my/project/src/utils.ts"), - project_root=Path("/my/project") + func = FunctionToOptimize( + function_name="myHelper", file_path=Path("/my/project/src/utils.ts"), parents=[], language="javascript" ) + refs = find_references(func, project_root=Path("/my/project")) for ref in refs: print(f"{ref.file_path}:{ref.line}:{ref.column} - {ref.reference_type}") ``` """ if project_root is None: - project_root = source_file.parent + project_root = function_to_optimize.file_path.parent finder = ReferenceFinder(project_root) - return finder.find_references(function_name, source_file, max_files=max_files, class_name=class_name) + return finder.find_references(function_to_optimize, max_files=max_files) diff --git a/codeflash/languages/javascript/instrument.py b/codeflash/languages/javascript/instrument.py index 9760de717..d8ddad489 100644 --- a/codeflash/languages/javascript/instrument.py +++ b/codeflash/languages/javascript/instrument.py @@ -15,7 +15,8 @@ if TYPE_CHECKING: from codeflash.code_utils.code_position import CodePosition - from codeflash.discovery.functions_to_optimize import FunctionToOptimize + +from codeflash.discovery.functions_to_optimize import FunctionToOptimize class TestingMode: @@ -72,15 +73,16 @@ class StandaloneCallTransformer: """ - def __init__(self, func_name: str, qualified_name: str, capture_func: str) -> None: - self.func_name = func_name - self.qualified_name = qualified_name + def __init__(self, function_to_optimize: FunctionToOptimize, capture_func: str) -> None: + self.function_to_optimize = function_to_optimize + self.func_name = function_to_optimize.function_name + self.qualified_name = function_to_optimize.qualified_name self.capture_func = capture_func self.invocation_counter = 0 # Pattern to match func_name( with optional leading await and optional object prefix # Captures: (whitespace)(await )?(object.)*func_name( # We'll filter out expect() and codeflash. cases in the transform loop - self._call_pattern = re.compile(rf"(\s*)(await\s+)?((?:\w+\.)*){re.escape(func_name)}\s*\(") + self._call_pattern = re.compile(rf"(\s*)(await\s+)?((?:\w+\.)*){re.escape(self.func_name)}\s*\(") def transform(self, code: str) -> str: """Transform all standalone calls in the code.""" @@ -310,7 +312,7 @@ def _generate_transformed_call(self, match: StandaloneCallMatch) -> str: def transform_standalone_calls( - code: str, func_name: str, qualified_name: str, capture_func: str, start_counter: int = 0 + code: str, function_to_optimize: FunctionToOptimize, capture_func: str, start_counter: int = 0 ) -> tuple[str, int]: """Transform standalone func(...) calls in JavaScript test code. @@ -318,8 +320,7 @@ def transform_standalone_calls( Args: code: The test code to transform. - func_name: Name of the function being tested. - qualified_name: Fully qualified function name. + function_to_optimize: The function being tested. capture_func: The capture function to use ('capture' or 'capturePerf'). start_counter: Starting value for the invocation counter. @@ -327,9 +328,7 @@ def transform_standalone_calls( Tuple of (transformed code, final counter value). """ - transformer = StandaloneCallTransformer( - func_name=func_name, qualified_name=qualified_name, capture_func=capture_func - ) + transformer = StandaloneCallTransformer(function_to_optimize=function_to_optimize, capture_func=capture_func) transformer.invocation_counter = start_counter result = transformer.transform(code) return result, transformer.invocation_counter @@ -348,15 +347,18 @@ class ExpectCallTransformer: - Multi-arg assertions: expect(func(args)).toBeCloseTo(0.5, 2) """ - def __init__(self, func_name: str, qualified_name: str, capture_func: str, remove_assertions: bool = False) -> None: - self.func_name = func_name - self.qualified_name = qualified_name + def __init__( + self, function_to_optimize: FunctionToOptimize, capture_func: str, remove_assertions: bool = False + ) -> None: + self.function_to_optimize = function_to_optimize + self.func_name = function_to_optimize.function_name + self.qualified_name = function_to_optimize.qualified_name self.capture_func = capture_func self.remove_assertions = remove_assertions self.invocation_counter = 0 # Pattern to match start of expect((object.)*func_name( # Captures: (whitespace), (object prefix like calc. or this.) - self._expect_pattern = re.compile(rf"(\s*)expect\s*\(\s*((?:\w+\.)*){re.escape(func_name)}\s*\(") + self._expect_pattern = re.compile(rf"(\s*)expect\s*\(\s*((?:\w+\.)*){re.escape(self.func_name)}\s*\(") def transform(self, code: str) -> str: """Transform all expect calls in the code.""" @@ -601,7 +603,7 @@ def _generate_transformed_call(self, match: ExpectCallMatch) -> str: def transform_expect_calls( - code: str, func_name: str, qualified_name: str, capture_func: str, remove_assertions: bool = False + code: str, function_to_optimize: FunctionToOptimize, capture_func: str, remove_assertions: bool = False ) -> tuple[str, int]: """Transform expect(func(...)).assertion() calls in JavaScript test code. @@ -609,8 +611,7 @@ def transform_expect_calls( Args: code: The test code to transform. - func_name: Name of the function being tested. - qualified_name: Fully qualified function name. + function_to_optimize: The function being tested. capture_func: The capture function to use ('capture' or 'capturePerf'). remove_assertions: If True, remove assertions entirely (for generated tests). @@ -619,10 +620,7 @@ def transform_expect_calls( """ transformer = ExpectCallTransformer( - func_name=func_name, - qualified_name=qualified_name, - capture_func=capture_func, - remove_assertions=remove_assertions, + function_to_optimize=function_to_optimize, capture_func=capture_func, remove_assertions=remove_assertions ) result = transformer.transform(code) return result, transformer.invocation_counter @@ -658,8 +656,6 @@ def inject_profiling_into_existing_js_test( logger.error(f"Failed to read test file {test_path}: {e}") return False, None - func_name = function_to_optimize.function_name - # Get the relative path for test identification try: rel_path = test_path.relative_to(tests_project_root) @@ -667,14 +663,12 @@ def inject_profiling_into_existing_js_test( rel_path = test_path # Check if the function is imported/required in this test file - if not _is_function_used_in_test(test_code, func_name): - logger.debug(f"Function '{func_name}' not found in test file {test_path}") + if not _is_function_used_in_test(test_code, function_to_optimize.function_name): + logger.debug(f"Function '{function_to_optimize.function_name}' not found in test file {test_path}") return False, None # Instrument the test code - instrumented_code = _instrument_js_test_code( - test_code, func_name, str(rel_path), mode, function_to_optimize.qualified_name - ) + instrumented_code = _instrument_js_test_code(test_code, function_to_optimize, str(rel_path), mode) if instrumented_code == test_code: logger.debug(f"No changes made to test file {test_path}") @@ -716,16 +710,15 @@ def _is_function_used_in_test(code: str, func_name: str) -> bool: def _instrument_js_test_code( - code: str, func_name: str, test_file_path: str, mode: str, qualified_name: str, remove_assertions: bool = False + code: str, function_to_optimize: FunctionToOptimize, test_file_path: str, mode: str, remove_assertions: bool = False ) -> str: """Instrument JavaScript test code with profiling capture calls. Args: code: Original test code. - func_name: Name of the function to instrument. + function_to_optimize: The function to instrument. test_file_path: Relative path to test file. mode: Testing mode (behavior or performance). - qualified_name: Fully qualified function name. remove_assertions: If True, remove expect assertions entirely (for generated/regression tests). If False, keep the expect wrapper (for existing user-written tests). @@ -771,8 +764,7 @@ def _instrument_js_test_code( # Transform expect calls using the refactored transformer code, expect_counter = transform_expect_calls( code=code, - func_name=func_name, - qualified_name=qualified_name, + function_to_optimize=function_to_optimize, capture_func=capture_func, remove_assertions=remove_assertions, ) @@ -780,11 +772,7 @@ def _instrument_js_test_code( # Transform standalone calls (not inside expect wrappers) # Continue counter from expect transformer to ensure unique IDs code, _final_counter = transform_standalone_calls( - code=code, - func_name=func_name, - qualified_name=qualified_name, - capture_func=capture_func, - start_counter=expect_counter, + code=code, function_to_optimize=function_to_optimize, capture_func=capture_func, start_counter=expect_counter ) return code @@ -941,7 +929,7 @@ def get_instrumented_test_path(original_path: Path, mode: str) -> Path: def instrument_generated_js_test( - test_code: str, function_name: str, qualified_name: str, mode: str = TestingMode.BEHAVIOR + test_code: str, function_to_optimize: FunctionToOptimize, mode: str = TestingMode.BEHAVIOR ) -> str: """Instrument generated JavaScript/TypeScript test code. @@ -956,8 +944,7 @@ def instrument_generated_js_test( Args: test_code: The generated test code to instrument. - function_name: Name of the function being tested. - qualified_name: Fully qualified function name (e.g., 'module.funcName'). + function_to_optimize: The function being tested. mode: Testing mode - "behavior" or "performance". Returns: @@ -971,9 +958,8 @@ def instrument_generated_js_test( # Generated tests are treated as regression tests, so we remove LLM-generated assertions return _instrument_js_test_code( code=test_code, - func_name=function_name, + function_to_optimize=function_to_optimize, test_file_path="generated_test", mode=mode, - qualified_name=qualified_name, remove_assertions=True, ) diff --git a/codeflash/languages/javascript/support.py b/codeflash/languages/javascript/support.py index 0e77c2c5e..8cdd54e47 100644 --- a/codeflash/languages/javascript/support.py +++ b/codeflash/languages/javascript/support.py @@ -983,14 +983,14 @@ def find_references( List of ReferenceInfo objects describing each reference location. """ + from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.languages.base import ReferenceInfo from codeflash.languages.javascript.find_references import ReferenceFinder try: finder = ReferenceFinder(project_root) - refs = finder.find_references( - function.name, function.file_path, max_files=max_files, class_name=function.class_name - ) + func_to_optimize = FunctionToOptimize.from_function_info(function) + refs = finder.find_references(func_to_optimize, max_files=max_files) # Convert to ReferenceInfo and filter out tests result: list[ReferenceInfo] = [] diff --git a/codeflash/verification/verifier.py b/codeflash/verification/verifier.py index 19500b968..c53f71cd5 100644 --- a/codeflash/verification/verifier.py +++ b/codeflash/verification/verifier.py @@ -72,11 +72,11 @@ def generate_tests( from codeflash.languages.javascript.module_system import ensure_module_system_compatibility source_file = Path(function_to_optimize.file_path) - func_name = function_to_optimize.function_name - qualified_name = function_to_optimize.qualified_name # Validate and fix import styles (default vs named exports) - generated_test_source = validate_and_fix_import_style(generated_test_source, source_file, func_name) + generated_test_source = validate_and_fix_import_style( + generated_test_source, source_file, function_to_optimize.function_name + ) # Convert module system if needed (e.g., CommonJS -> ESM for ESM projects) generated_test_source = ensure_module_system_compatibility(generated_test_source, project_module_system) @@ -84,20 +84,18 @@ def generate_tests( # Instrument for behavior verification (writes to SQLite) instrumented_behavior_test_source = instrument_generated_js_test( test_code=generated_test_source, - function_name=func_name, - qualified_name=qualified_name, + function_to_optimize=function_to_optimize, mode=TestingMode.BEHAVIOR, ) # Instrument for performance measurement (prints to stdout) instrumented_perf_test_source = instrument_generated_js_test( test_code=generated_test_source, - function_name=func_name, - qualified_name=qualified_name, + function_to_optimize=function_to_optimize, mode=TestingMode.PERFORMANCE, ) - logger.debug(f"Instrumented JS/TS tests locally for {func_name}") + logger.debug(f"Instrumented JS/TS tests locally for {function_to_optimize.function_name}") else: # Python: instrumentation is done by aiservice, just replace temp dir placeholders instrumented_behavior_test_source = instrumented_behavior_test_source.replace( diff --git a/tests/test_javascript_assertion_removal.py b/tests/test_javascript_assertion_removal.py index ac1a34cbe..e0ee483e8 100644 --- a/tests/test_javascript_assertion_removal.py +++ b/tests/test_javascript_assertion_removal.py @@ -6,7 +6,22 @@ from __future__ import annotations +from pathlib import Path + +from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.languages.javascript.instrument import TestingMode, instrument_generated_js_test, transform_expect_calls +from codeflash.models.models import FunctionParent + + +def make_func(name: str, class_name: str | None = None) -> FunctionToOptimize: + """Helper to create FunctionToOptimize for testing.""" + parents = [FunctionParent(name=class_name, type="ClassDef")] if class_name else [] + return FunctionToOptimize( + function_name=name, + file_path=Path("/test/file.js"), + parents=parents, + language="javascript", + ) class TestExpectCallTransformer: @@ -15,139 +30,139 @@ class TestExpectCallTransformer: def test_basic_toBe_assertion(self) -> None: """Test basic .toBe() assertion removal.""" code = "expect(fibonacci(5)).toBe(5);" - result, _ = transform_expect_calls(code, "fibonacci", "fibonacci", "capture", remove_assertions=True) + result, _ = transform_expect_calls(code, make_func("fibonacci"), "capture", remove_assertions=True) assert result == "codeflash.capture('fibonacci', '1', fibonacci, 5);" def test_basic_toEqual_assertion(self) -> None: """Test .toEqual() assertion removal.""" code = "expect(func([1, 2, 3])).toEqual([1, 2, 3]);" - result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True) + result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True) assert result == "codeflash.capture('func', '1', func, [1, 2, 3]);" def test_toStrictEqual_assertion(self) -> None: """Test .toStrictEqual() assertion removal.""" code = "expect(func({a: 1})).toStrictEqual({a: 1});" - result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True) + result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True) assert result == "codeflash.capture('func', '1', func, {a: 1});" def test_toBeCloseTo_with_precision(self) -> None: """Test .toBeCloseTo() with precision argument.""" code = "expect(func(3.14159)).toBeCloseTo(3.14, 2);" - result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True) + result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True) assert result == "codeflash.capture('func', '1', func, 3.14159);" def test_toBeTruthy_no_args(self) -> None: """Test .toBeTruthy() assertion without arguments.""" code = "expect(func(true)).toBeTruthy();" - result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True) + result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True) assert result == "codeflash.capture('func', '1', func, true);" def test_toBeFalsy_no_args(self) -> None: """Test .toBeFalsy() assertion without arguments.""" code = "expect(func(0)).toBeFalsy();" - result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True) + result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True) assert result == "codeflash.capture('func', '1', func, 0);" def test_toBeNull(self) -> None: """Test .toBeNull() assertion.""" code = "expect(func(null)).toBeNull();" - result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True) + result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True) assert result == "codeflash.capture('func', '1', func, null);" def test_toBeUndefined(self) -> None: """Test .toBeUndefined() assertion.""" code = "expect(func()).toBeUndefined();" - result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True) + result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True) assert result == "codeflash.capture('func', '1', func);" def test_toBeDefined(self) -> None: """Test .toBeDefined() assertion.""" code = "expect(func(1)).toBeDefined();" - result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True) + result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True) assert result == "codeflash.capture('func', '1', func, 1);" def test_toBeNaN(self) -> None: """Test .toBeNaN() assertion.""" code = "expect(func(NaN)).toBeNaN();" - result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True) + result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True) assert result == "codeflash.capture('func', '1', func, NaN);" def test_toBeGreaterThan(self) -> None: """Test .toBeGreaterThan() assertion.""" code = "expect(func(10)).toBeGreaterThan(5);" - result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True) + result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True) assert result == "codeflash.capture('func', '1', func, 10);" def test_toBeLessThan(self) -> None: """Test .toBeLessThan() assertion.""" code = "expect(func(3)).toBeLessThan(5);" - result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True) + result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True) assert result == "codeflash.capture('func', '1', func, 3);" def test_toBeGreaterThanOrEqual(self) -> None: """Test .toBeGreaterThanOrEqual() assertion.""" code = "expect(func(5)).toBeGreaterThanOrEqual(5);" - result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True) + result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True) assert result == "codeflash.capture('func', '1', func, 5);" def test_toBeLessThanOrEqual(self) -> None: """Test .toBeLessThanOrEqual() assertion.""" code = "expect(func(5)).toBeLessThanOrEqual(5);" - result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True) + result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True) assert result == "codeflash.capture('func', '1', func, 5);" def test_toContain(self) -> None: """Test .toContain() assertion.""" code = "expect(func([1, 2, 3])).toContain(2);" - result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True) + result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True) assert result == "codeflash.capture('func', '1', func, [1, 2, 3]);" def test_toContainEqual(self) -> None: """Test .toContainEqual() assertion.""" code = "expect(func([{a: 1}])).toContainEqual({a: 1});" - result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True) + result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True) assert result == "codeflash.capture('func', '1', func, [{a: 1}]);" def test_toHaveLength(self) -> None: """Test .toHaveLength() assertion.""" code = "expect(func([1, 2, 3])).toHaveLength(3);" - result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True) + result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True) assert result == "codeflash.capture('func', '1', func, [1, 2, 3]);" def test_toMatch_string(self) -> None: """Test .toMatch() with string pattern.""" code = "expect(func('hello')).toMatch('ell');" - result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True) + result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True) assert result == "codeflash.capture('func', '1', func, 'hello');" def test_toMatch_regex(self) -> None: """Test .toMatch() with regex pattern.""" code = "expect(func('hello')).toMatch(/ell/);" - result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True) + result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True) assert result == "codeflash.capture('func', '1', func, 'hello');" def test_toMatchObject(self) -> None: """Test .toMatchObject() assertion.""" code = "expect(func({a: 1, b: 2})).toMatchObject({a: 1});" - result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True) + result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True) assert result == "codeflash.capture('func', '1', func, {a: 1, b: 2});" def test_toHaveProperty(self) -> None: """Test .toHaveProperty() assertion.""" code = "expect(func({a: 1})).toHaveProperty('a');" - result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True) + result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True) assert result == "codeflash.capture('func', '1', func, {a: 1});" def test_toHaveProperty_with_value(self) -> None: """Test .toHaveProperty() with value.""" code = "expect(func({a: 1})).toHaveProperty('a', 1);" - result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True) + result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True) assert result == "codeflash.capture('func', '1', func, {a: 1});" def test_toBeInstanceOf(self) -> None: """Test .toBeInstanceOf() assertion.""" code = "expect(func()).toBeInstanceOf(Array);" - result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True) + result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True) assert result == "codeflash.capture('func', '1', func);" @@ -157,31 +172,31 @@ class TestNegatedAssertions: def test_not_toBe(self) -> None: """Test .not.toBe() assertion removal.""" code = "expect(func(5)).not.toBe(10);" - result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True) + result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True) assert result == "codeflash.capture('func', '1', func, 5);" def test_not_toEqual(self) -> None: """Test .not.toEqual() assertion removal.""" code = "expect(func([1, 2])).not.toEqual([3, 4]);" - result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True) + result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True) assert result == "codeflash.capture('func', '1', func, [1, 2]);" def test_not_toBeTruthy(self) -> None: """Test .not.toBeTruthy() assertion removal.""" code = "expect(func(0)).not.toBeTruthy();" - result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True) + result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True) assert result == "codeflash.capture('func', '1', func, 0);" def test_not_toContain(self) -> None: """Test .not.toContain() assertion removal.""" code = "expect(func([1, 2, 3])).not.toContain(4);" - result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True) + result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True) assert result == "codeflash.capture('func', '1', func, [1, 2, 3]);" def test_not_toBeNull(self) -> None: """Test .not.toBeNull() assertion removal.""" code = "expect(func(1)).not.toBeNull();" - result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True) + result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True) assert result == "codeflash.capture('func', '1', func, 1);" @@ -191,31 +206,31 @@ class TestAsyncAssertions: def test_resolves_toBe(self) -> None: """Test .resolves.toBe() assertion removal.""" code = "expect(asyncFunc(5)).resolves.toBe(10);" - result, _ = transform_expect_calls(code, "asyncFunc", "asyncFunc", "capture", remove_assertions=True) + result, _ = transform_expect_calls(code, make_func("asyncFunc"), "capture", remove_assertions=True) assert result == "codeflash.capture('asyncFunc', '1', asyncFunc, 5);" def test_resolves_toEqual(self) -> None: """Test .resolves.toEqual() assertion removal.""" code = "expect(asyncFunc()).resolves.toEqual({data: 'test'});" - result, _ = transform_expect_calls(code, "asyncFunc", "asyncFunc", "capture", remove_assertions=True) + result, _ = transform_expect_calls(code, make_func("asyncFunc"), "capture", remove_assertions=True) assert result == "codeflash.capture('asyncFunc', '1', asyncFunc);" def test_rejects_toThrow(self) -> None: """Test .rejects.toThrow() assertion removal.""" code = "expect(asyncFunc()).rejects.toThrow();" - result, _ = transform_expect_calls(code, "asyncFunc", "asyncFunc", "capture", remove_assertions=True) + result, _ = transform_expect_calls(code, make_func("asyncFunc"), "capture", remove_assertions=True) assert result == "codeflash.capture('asyncFunc', '1', asyncFunc);" def test_rejects_toThrow_with_message(self) -> None: """Test .rejects.toThrow() with error message.""" code = "expect(asyncFunc()).rejects.toThrow('Error message');" - result, _ = transform_expect_calls(code, "asyncFunc", "asyncFunc", "capture", remove_assertions=True) + result, _ = transform_expect_calls(code, make_func("asyncFunc"), "capture", remove_assertions=True) assert result == "codeflash.capture('asyncFunc', '1', asyncFunc);" def test_not_resolves_toBe(self) -> None: """Test .not.resolves.toBe() (rare but valid).""" code = "expect(asyncFunc()).not.resolves.toBe(5);" - result, _ = transform_expect_calls(code, "asyncFunc", "asyncFunc", "capture", remove_assertions=True) + result, _ = transform_expect_calls(code, make_func("asyncFunc"), "capture", remove_assertions=True) assert result == "codeflash.capture('asyncFunc', '1', asyncFunc);" @@ -225,31 +240,31 @@ class TestNestedParentheses: def test_nested_function_call(self) -> None: """Test nested function call in arguments.""" code = "expect(func(getN(5))).toBe(10);" - result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True) + result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True) assert result == "codeflash.capture('func', '1', func, getN(5));" def test_deeply_nested_calls(self) -> None: """Test deeply nested function calls.""" code = "expect(func(outer(inner(deep(1))))).toBe(100);" - result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True) + result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True) assert result == "codeflash.capture('func', '1', func, outer(inner(deep(1))));" def test_multiple_nested_args(self) -> None: """Test multiple arguments with nested calls.""" code = "expect(func(getA(), getB(getC()))).toBe(5);" - result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True) + result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True) assert result == "codeflash.capture('func', '1', func, getA(), getB(getC()));" def test_object_with_nested_calls(self) -> None: """Test object argument with nested function calls.""" code = "expect(func({key: getValue()})).toEqual({key: 1});" - result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True) + result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True) assert result == "codeflash.capture('func', '1', func, {key: getValue()});" def test_array_with_nested_calls(self) -> None: """Test array argument with nested function calls.""" code = "expect(func([getA(), getB()])).toEqual([1, 2]);" - result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True) + result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True) assert result == "codeflash.capture('func', '1', func, [getA(), getB()]);" @@ -259,31 +274,31 @@ class TestStringLiterals: def test_string_with_parentheses(self) -> None: """Test string argument containing parentheses.""" code = "expect(func('hello (world)')).toBe('result');" - result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True) + result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True) assert result == "codeflash.capture('func', '1', func, 'hello (world)');" def test_double_quoted_string_with_parens(self) -> None: """Test double-quoted string with parentheses.""" code = 'expect(func("hello (world)")).toBe("result");' - result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True) + result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True) assert result == "codeflash.capture('func', '1', func, \"hello (world)\");" def test_template_literal(self) -> None: """Test template literal argument.""" code = "expect(func(`template ${value}`)).toBe('result');" - result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True) + result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True) assert result == "codeflash.capture('func', '1', func, `template ${value}`);" def test_template_literal_with_parens(self) -> None: """Test template literal with parentheses inside.""" code = "expect(func(`hello (${name})`)).toBe('greeting');" - result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True) + result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True) assert result == "codeflash.capture('func', '1', func, `hello (${name})`);" def test_escaped_quotes(self) -> None: """Test string with escaped quotes.""" code = "expect(func('it\\'s working')).toBe('yes');" - result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True) + result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True) assert result == "codeflash.capture('func', '1', func, 'it\\'s working');" @@ -293,39 +308,39 @@ class TestWhitespaceHandling: def test_leading_whitespace_preserved(self) -> None: """Test that leading whitespace is preserved.""" code = " expect(func(5)).toBe(5);" - result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True) + result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True) assert result == " codeflash.capture('func', '1', func, 5);" def test_tab_indentation(self) -> None: """Test tab indentation is preserved.""" code = "\t\texpect(func(5)).toBe(5);" - result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True) + result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True) assert result == "\t\tcodeflash.capture('func', '1', func, 5);" def test_no_space_after_expect(self) -> None: """Test expect without space before parenthesis.""" code = "expect(func(5)).toBe(5);" - result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True) + result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True) assert result == "codeflash.capture('func', '1', func, 5);" def test_space_after_expect(self) -> None: """Test expect with space before parenthesis.""" code = "expect (func(5)).toBe(5);" - result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True) + result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True) assert result == "codeflash.capture('func', '1', func, 5);" def test_newline_in_assertion(self) -> None: """Test assertion split across lines.""" code = """expect(func(5)) .toBe(5);""" - result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True) + result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True) assert result == "codeflash.capture('func', '1', func, 5);" def test_newline_after_expect_close(self) -> None: """Test newline after expect closing paren.""" code = """expect(func(5)) .toBe(5);""" - result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True) + result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True) assert result == "codeflash.capture('func', '1', func, 5);" @@ -337,7 +352,7 @@ def test_multiple_assertions_same_function(self) -> None: code = """expect(func(1)).toBe(1); expect(func(2)).toBe(2); expect(func(3)).toBe(3);""" - result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True) + result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True) expected = """codeflash.capture('func', '1', func, 1); codeflash.capture('func', '2', func, 2); codeflash.capture('func', '3', func, 3);""" @@ -348,7 +363,7 @@ def test_multiple_different_assertions(self) -> None: code = """expect(func(1)).toBe(1); expect(func(2)).toEqual(2); expect(func(3)).not.toBe(0);""" - result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True) + result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True) expected = """codeflash.capture('func', '1', func, 1); codeflash.capture('func', '2', func, 2); codeflash.capture('func', '3', func, 3);""" @@ -360,7 +375,7 @@ def test_mixed_with_other_code(self) -> None: expect(func(x)).toBe(10); console.log('done'); expect(func(x + 1)).toBe(12);""" - result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True) + result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True) expected = """const x = 5; codeflash.capture('func', '1', func, x); console.log('done'); @@ -374,20 +389,20 @@ class TestSemicolonHandling: def test_with_semicolon(self) -> None: """Test assertion with trailing semicolon.""" code = "expect(func(5)).toBe(5);" - result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True) + result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True) assert result == "codeflash.capture('func', '1', func, 5);" def test_without_semicolon(self) -> None: """Test assertion without trailing semicolon.""" code = "expect(func(5)).toBe(5)" - result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True) + result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True) assert result == "codeflash.capture('func', '1', func, 5);" def test_multiple_without_semicolons(self) -> None: """Test multiple assertions without semicolons (common in some styles).""" code = """expect(func(1)).toBe(1) expect(func(2)).toBe(2)""" - result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True) + result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True) expected = """codeflash.capture('func', '1', func, 1); codeflash.capture('func', '2', func, 2);""" assert result == expected @@ -399,25 +414,25 @@ class TestPreservingAssertions: def test_preserve_toBe(self) -> None: """Test preserving .toBe() assertion.""" code = "expect(func(5)).toBe(5);" - result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=False) + result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=False) assert result == "expect(codeflash.capture('func', '1', func, 5)).toBe(5);" def test_preserve_not_toBe(self) -> None: """Test preserving .not.toBe() assertion.""" code = "expect(func(5)).not.toBe(10);" - result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=False) + result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=False) assert result == "expect(codeflash.capture('func', '1', func, 5)).not.toBe(10);" def test_preserve_resolves(self) -> None: """Test preserving .resolves assertion.""" code = "expect(asyncFunc(5)).resolves.toBe(10);" - result, _ = transform_expect_calls(code, "asyncFunc", "asyncFunc", "capture", remove_assertions=False) + result, _ = transform_expect_calls(code, make_func("asyncFunc"), "capture", remove_assertions=False) assert result == "expect(codeflash.capture('asyncFunc', '1', asyncFunc, 5)).resolves.toBe(10);" def test_preserve_toBeCloseTo(self) -> None: """Test preserving .toBeCloseTo() with args.""" code = "expect(func(3.14159)).toBeCloseTo(3.14, 2);" - result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=False) + result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=False) assert result == "expect(codeflash.capture('func', '1', func, 3.14159)).toBeCloseTo(3.14, 2);" @@ -427,13 +442,13 @@ class TestCaptureFunction: def test_behavior_mode_uses_capture(self) -> None: """Test behavior mode uses capture function.""" code = "expect(func(5)).toBe(5);" - result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True) + result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True) assert "codeflash.capture(" in result def test_performance_mode_uses_capturePerf(self) -> None: """Test performance mode uses capturePerf function.""" code = "expect(func(5)).toBe(5);" - result, _ = transform_expect_calls(code, "func", "func", "capturePerf", remove_assertions=True) + result, _ = transform_expect_calls(code, make_func("func"), "capturePerf", remove_assertions=True) assert "codeflash.capturePerf(" in result @@ -443,13 +458,19 @@ class TestQualifiedNames: def test_simple_qualified_name(self) -> None: """Test simple qualified name.""" code = "expect(func(5)).toBe(5);" - result, _ = transform_expect_calls(code, "func", "module.func", "capture", remove_assertions=True) + result, _ = transform_expect_calls(code, make_func("func", class_name="module"), "capture", remove_assertions=True) assert result == "codeflash.capture('module.func', '1', func, 5);" def test_nested_qualified_name(self) -> None: """Test nested qualified name.""" code = "expect(func(5)).toBe(5);" - result, _ = transform_expect_calls(code, "func", "pkg.module.func", "capture", remove_assertions=True) + func = FunctionToOptimize( + function_name="func", + file_path=Path("/test/file.js"), + parents=[FunctionParent(name="pkg", type="ClassDef"), FunctionParent(name="module", type="ClassDef")], + language="javascript", + ) + result, _ = transform_expect_calls(code, func, "capture", remove_assertions=True) assert result == "codeflash.capture('pkg.module.func', '1', func, 5);" @@ -459,7 +480,7 @@ class TestEdgeCases: def test_function_name_as_substring(self) -> None: """Test that function name matching is exact.""" code = "expect(myFunc(5)).toBe(5); expect(func(10)).toBe(10);" - result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True) + result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True) # Should only transform func, not myFunc assert "expect(myFunc(5)).toBe(5)" in result assert "codeflash.capture('func', '1', func, 10)" in result @@ -467,26 +488,26 @@ def test_function_name_as_substring(self) -> None: def test_empty_args(self) -> None: """Test function call with no arguments.""" code = "expect(func()).toBe(undefined);" - result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True) + result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True) assert result == "codeflash.capture('func', '1', func);" def test_object_method_style(self) -> None: """Test that method calls on objects are not matched.""" code = "expect(obj.func(5)).toBe(5);" - result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True) + result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True) # Should not transform method calls assert result == "expect(obj.func(5)).toBe(5);" def test_non_matching_code_unchanged(self) -> None: """Test that non-matching code remains unchanged.""" code = "const x = func(5); console.log(x);" - result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True) + result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True) assert result == code def test_expect_without_assertion(self) -> None: """Test expect without assertion is not transformed.""" code = "const result = expect(func(5));" - result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True) + result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True) # Should not transform as there's no assertion assert result == code @@ -504,7 +525,7 @@ def test_full_test_file_behavior_mode(self) -> None: expect(fibonacci(10)).toBe(55); }); });""" - result = instrument_generated_js_test(code, "fibonacci", "fibonacci", TestingMode.BEHAVIOR) + result = instrument_generated_js_test(code, make_func("fibonacci"), TestingMode.BEHAVIOR) assert "import codeflash from 'codeflash'" in result assert "codeflash.capture('fibonacci'" in result assert ".toBe(" not in result @@ -518,7 +539,7 @@ def test_full_test_file_performance_mode(self) -> None: expect(fibonacci(5)).toBe(5); }); });""" - result = instrument_generated_js_test(code, "fibonacci", "fibonacci", TestingMode.PERFORMANCE) + result = instrument_generated_js_test(code, make_func("fibonacci"), TestingMode.PERFORMANCE) assert "import codeflash from 'codeflash'" in result assert "codeflash.capturePerf('fibonacci'" in result assert ".toBe(" not in result @@ -532,7 +553,7 @@ def test_commonjs_import_style(self) -> None: expect(fibonacci(5)).toBe(5); }); });""" - result = instrument_generated_js_test(code, "fibonacci", "fibonacci", TestingMode.BEHAVIOR) + result = instrument_generated_js_test(code, make_func("fibonacci"), TestingMode.BEHAVIOR) assert "const codeflash = require('codeflash')" in result assert "codeflash.capture('fibonacci'" in result @@ -549,7 +570,7 @@ def test_various_assertion_types(self) -> None: expect(func(null)).toBeNull(); }); });""" - result = instrument_generated_js_test(code, "func", "func", TestingMode.BEHAVIOR) + result = instrument_generated_js_test(code, make_func("func"), TestingMode.BEHAVIOR) # All assertions should be removed assert ".toBe(" not in result assert ".not." not in result @@ -561,12 +582,12 @@ def test_various_assertion_types(self) -> None: def test_empty_code(self) -> None: """Test with empty code.""" - result = instrument_generated_js_test("", "func", "func", TestingMode.BEHAVIOR) + result = instrument_generated_js_test("", make_func("func"), TestingMode.BEHAVIOR) assert result == "" def test_whitespace_only_code(self) -> None: """Test with whitespace-only code.""" - result = instrument_generated_js_test(" \n\t ", "func", "func", TestingMode.BEHAVIOR) + result = instrument_generated_js_test(" \n\t ", make_func("func"), TestingMode.BEHAVIOR) assert result == " \n\t " @@ -594,7 +615,7 @@ def test_jest_describe_test_structure(self) -> None: }); }); });""" - result = instrument_generated_js_test(code, "processData", "processData", TestingMode.BEHAVIOR) + result = instrument_generated_js_test(code, make_func("processData"), TestingMode.BEHAVIOR) assert result.count("codeflash.capture(") == 3 assert "toEqual(" not in result assert "toBeNull(" not in result @@ -612,7 +633,7 @@ def test_vitest_it_structure(self) -> None: expect(calculate(2, 3, 'mul')).toBe(6); }); });""" - result = instrument_generated_js_test(code, "calculate", "calculate", TestingMode.BEHAVIOR) + result = instrument_generated_js_test(code, make_func("calculate"), TestingMode.BEHAVIOR) assert result.count("codeflash.capture(") == 2 assert ".toBe(" not in result @@ -629,7 +650,7 @@ def test_async_await_pattern(self) -> None: expect(fetchData('/invalid')).rejects.toThrow('Not found'); }); });""" - result = instrument_generated_js_test(code, "fetchData", "fetchData", TestingMode.BEHAVIOR) + result = instrument_generated_js_test(code, make_func("fetchData"), TestingMode.BEHAVIOR) assert result.count("codeflash.capture(") == 2 assert ".resolves." not in result assert ".rejects." not in result @@ -647,6 +668,6 @@ def test_numeric_precision_tests(self) -> None: expect(calculatePi(5)).toBeCloseTo(3.14159, 5); }); });""" - result = instrument_generated_js_test(code, "calculatePi", "calculatePi", TestingMode.BEHAVIOR) + result = instrument_generated_js_test(code, make_func("calculatePi"), TestingMode.BEHAVIOR) assert result.count("codeflash.capture(") == 2 assert ".toBeCloseTo(" not in result diff --git a/tests/test_languages/test_find_references.py b/tests/test_languages/test_find_references.py index 8bbd5ce09..970ff05d9 100644 --- a/tests/test_languages/test_find_references.py +++ b/tests/test_languages/test_find_references.py @@ -11,6 +11,7 @@ import pytest from pathlib import Path +from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.languages.javascript.find_references import ( Reference, ReferenceFinder, @@ -20,6 +21,18 @@ ) from codeflash.languages.base import Language, FunctionInfo, ReferenceInfo from codeflash.code_utils.code_extractor import _format_references_as_markdown +from codeflash.models.models import FunctionParent + + +def make_func(name: str, file_path: Path, class_name: str | None = None) -> FunctionToOptimize: + """Helper to create FunctionToOptimize for testing.""" + parents = [FunctionParent(name=class_name, type="ClassDef")] if class_name else [] + return FunctionToOptimize( + function_name=name, + file_path=file_path, + parents=parents, + language="javascript", + ) class TestReferenceFinder: @@ -111,7 +124,7 @@ def test_find_named_export_references_values(self, project_root): finder = ReferenceFinder(project_root) source_file = project_root / "src" / "utils" / "DynamicBindingUtils.ts" - refs = finder.find_references("getDynamicBindings", source_file) + refs = finder.find_references(make_func("getDynamicBindings", source_file)) # Sort refs by file path for consistent ordering refs_sorted = sorted(refs, key=lambda r: (str(r.file_path), r.line)) @@ -140,7 +153,7 @@ def test_format_references_as_markdown_named_exports(self, project_root): finder = ReferenceFinder(project_root) source_file = project_root / "src" / "utils" / "DynamicBindingUtils.ts" - refs = finder.find_references("getDynamicBindings", source_file) + refs = finder.find_references(make_func("getDynamicBindings", source_file)) # Convert to ReferenceInfo and sort for consistent ordering ref_infos = sorted([ @@ -221,7 +234,7 @@ def test_find_default_export_references_values(self, project_root): finder = ReferenceFinder(project_root) source_file = project_root / "src" / "helper.ts" - refs = finder.find_references("processData", source_file) + refs = finder.find_references(make_func("processData", source_file)) # Should find references in both files ref_files = {str(ref.file_path) for ref in refs} @@ -247,7 +260,7 @@ def test_format_references_as_markdown_default_exports(self, project_root): finder = ReferenceFinder(project_root) source_file = project_root / "src" / "helper.ts" - refs = finder.find_references("processData", source_file) + refs = finder.find_references(make_func("processData", source_file)) ref_infos = sorted([ ReferenceInfo( file_path=r.file_path, line=r.line, column=r.column, @@ -315,7 +328,7 @@ def test_find_reexport_reference_values(self, project_root): finder = ReferenceFinder(project_root) source_file = project_root / "src" / "utils" / "filterUtils.ts" - refs = finder.find_references("filterBySearchTerm", source_file) + refs = finder.find_references(make_func("filterBySearchTerm", source_file)) # Should find re-export in index.ts reexport_refs = [r for r in refs if r.reference_type == "reexport"] @@ -336,7 +349,7 @@ def test_format_references_as_markdown_reexports(self, project_root): finder = ReferenceFinder(project_root) source_file = project_root / "src" / "utils" / "filterUtils.ts" - refs = finder.find_references("filterBySearchTerm", source_file) + refs = finder.find_references(make_func("filterBySearchTerm", source_file)) ref_infos = sorted([ ReferenceInfo( file_path=r.file_path, line=r.line, column=r.column, @@ -395,7 +408,7 @@ def test_find_callback_references_values(self, project_root): finder = ReferenceFinder(project_root) source_file = project_root / "src" / "transforms.ts" - refs = finder.find_references("normalizeItem", source_file) + refs = finder.find_references(make_func("normalizeItem", source_file)) # Should find the callback reference callback_refs = [r for r in refs if r.reference_type == "callback"] @@ -412,7 +425,7 @@ def test_format_references_as_markdown_callbacks(self, project_root): finder = ReferenceFinder(project_root) source_file = project_root / "src" / "transforms.ts" - refs = finder.find_references("normalizeItem", source_file) + refs = finder.find_references(make_func("normalizeItem", source_file)) ref_infos = [ ReferenceInfo( file_path=r.file_path, line=r.line, column=r.column, @@ -469,7 +482,7 @@ def test_find_aliased_import_references_values(self, project_root): finder = ReferenceFinder(project_root) source_file = project_root / "src" / "utils.ts" - refs = finder.find_references("computeValue", source_file) + refs = finder.find_references(make_func("computeValue", source_file)) # Should find the reference even though it's called as "calculate" assert len(refs) == 1 @@ -486,7 +499,7 @@ def test_format_references_as_markdown_aliases(self, project_root): finder = ReferenceFinder(project_root) source_file = project_root / "src" / "utils.ts" - refs = finder.find_references("computeValue", source_file) + refs = finder.find_references(make_func("computeValue", source_file)) ref_infos = [ ReferenceInfo( file_path=r.file_path, line=r.line, column=r.column, @@ -542,7 +555,7 @@ def test_find_namespace_import_references_values(self, project_root): finder = ReferenceFinder(project_root) source_file = project_root / "src" / "mathUtils.ts" - refs = finder.find_references("add", source_file) + refs = finder.find_references(make_func("add", source_file)) assert len(refs) == 1 ref = refs[0] @@ -558,7 +571,7 @@ def test_format_references_as_markdown_namespace(self, project_root): finder = ReferenceFinder(project_root) source_file = project_root / "src" / "mathUtils.ts" - refs = finder.find_references("add", source_file) + refs = finder.find_references(make_func("add", source_file)) ref_infos = [ ReferenceInfo( file_path=r.file_path, line=r.line, column=r.column, @@ -616,7 +629,7 @@ def test_find_memoized_function_references_values(self, project_root): finder = ReferenceFinder(project_root) source_file = project_root / "src" / "expensive.ts" - refs = finder.find_references("computeExpensive", source_file) + refs = finder.find_references(make_func("computeExpensive", source_file)) # Should find memoize call and direct call assert len(refs) >= 2 @@ -657,7 +670,7 @@ def test_find_recursive_references_values(self, project_root): finder = ReferenceFinder(project_root) source_file = project_root / "src" / "recursive.ts" - refs = finder.find_references("factorial", source_file, include_definition=True) + refs = finder.find_references(make_func("factorial", source_file), include_definition=True) # Should find the recursive call call_refs = [r for r in refs if r.reference_type == "call"] @@ -709,7 +722,7 @@ def test_find_all_references_across_codebase_values(self, project_root): finder = ReferenceFinder(project_root) source_file = project_root / "src" / "utils" / "widgetUtils.ts" - refs = finder.find_references("isLargeWidget", source_file) + refs = finder.find_references(make_func("isLargeWidget", source_file)) # Should find re-export in index.ts reexport_refs = [r for r in refs if r.reference_type == "reexport"] @@ -729,7 +742,7 @@ def test_format_references_as_markdown_complex(self, project_root): finder = ReferenceFinder(project_root) source_file = project_root / "src" / "utils" / "widgetUtils.ts" - refs = finder.find_references("isLargeWidget", source_file) + refs = finder.find_references(make_func("isLargeWidget", source_file)) ref_infos = sorted([ ReferenceInfo( file_path=r.file_path, line=r.line, column=r.column, @@ -771,7 +784,7 @@ def test_nonexistent_file(self, project_root): finder = ReferenceFinder(project_root) source_file = project_root / "src" / "nonexistent.ts" - refs = finder.find_references("someFunction", source_file) + refs = finder.find_references(make_func("someFunction", source_file)) assert refs == [] @@ -791,7 +804,7 @@ def test_non_exported_function(self, project_root): finder = ReferenceFinder(project_root) source_file = project_root / "src" / "private.ts" - refs = finder.find_references("internalHelper", source_file) + refs = finder.find_references(make_func("internalHelper", source_file)) # Should only find internal reference assert all(r.file_path == source_file for r in refs) @@ -803,7 +816,7 @@ def test_empty_file(self, project_root): finder = ReferenceFinder(project_root) source_file = project_root / "src" / "empty.ts" - refs = finder.find_references("anything", source_file) + refs = finder.find_references(make_func("anything", source_file)) assert refs == [] @@ -849,7 +862,7 @@ def test_find_commonjs_references_values(self, project_root): finder = ReferenceFinder(project_root) source_file = project_root / "src" / "helpers.js" - refs = finder.find_references("processConfig", source_file) + refs = finder.find_references(make_func("processConfig", source_file)) assert len(refs) >= 1 main_ref = next((r for r in refs if "main.js" in str(r.file_path)), None) @@ -863,7 +876,7 @@ def test_format_references_as_markdown_commonjs(self, project_root): finder = ReferenceFinder(project_root) source_file = project_root / "src" / "helpers.js" - refs = finder.find_references("processConfig", source_file) + refs = finder.find_references(make_func("processConfig", source_file)) ref_infos = sorted([ ReferenceInfo( file_path=r.file_path, line=r.line, column=r.column, @@ -915,7 +928,7 @@ def test_find_references_function_values(self, project_root): """Test the find_references convenience function with exact values.""" source_file = project_root / "src" / "utils.ts" - refs = find_references("helper", source_file, project_root=project_root) + refs = find_references(make_func("helper", source_file), project_root=project_root) assert len(refs) == 1 ref = refs[0] @@ -1054,7 +1067,7 @@ def test_circular_import_handling(self, project_root): source_file = project_root / "src" / "a.ts" # Should not hang or crash - refs = finder.find_references("funcA", source_file) + refs = finder.find_references(make_func("funcA", source_file)) # Should find reference in b.ts b_refs = [r for r in refs if "b.ts" in str(r.file_path)] @@ -1084,5 +1097,5 @@ def test_syntax_error_graceful_handling(self, project_root): source_file = project_root / "src" / "valid.ts" # Should not crash - refs = finder.find_references("validFunction", source_file) + refs = finder.find_references(make_func("validFunction", source_file)) assert isinstance(refs, list) diff --git a/tests/test_languages/test_javascript_instrumentation.py b/tests/test_languages/test_javascript_instrumentation.py index e46c57ec7..423504f66 100644 --- a/tests/test_languages/test_javascript_instrumentation.py +++ b/tests/test_languages/test_javascript_instrumentation.py @@ -6,9 +6,22 @@ import tempfile from pathlib import Path +from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.languages.base import FunctionInfo, Language from codeflash.languages.javascript.line_profiler import JavaScriptLineProfiler from codeflash.languages.javascript.tracer import JavaScriptTracer +from codeflash.models.models import FunctionParent + + +def make_func(name: str, class_name: str | None = None) -> FunctionToOptimize: + """Helper to create FunctionToOptimize for testing.""" + parents = [FunctionParent(name=class_name, type="ClassDef")] if class_name else [] + return FunctionToOptimize( + function_name=name, + file_path=Path("/test/file.js"), + parents=parents, + language="javascript", + ) class TestJavaScriptLineProfiler: @@ -352,7 +365,7 @@ def test_instrument_method_call_on_instance(self): console.log(result); """ transformed, counter = transform_standalone_calls( - code=code, func_name="fibonacci", qualified_name="Calculator.fibonacci", capture_func="capture" + code=code, function_to_optimize=make_func("fibonacci", class_name="Calculator"), capture_func="capture" ) # Should transform calc.fibonacci(10) to codeflash.capture(..., calc.fibonacci.bind(calc), 10) @@ -371,7 +384,7 @@ def test_instrument_expect_with_method_call(self): }); """ transformed, counter = transform_expect_calls( - code=code, func_name="fibonacci", qualified_name="FibonacciCalculator.fibonacci", capture_func="capture" + code=code, function_to_optimize=make_func("fibonacci", class_name="FibonacciCalculator"), capture_func="capture" ) # Should transform expect(calc.fibonacci(10)) to @@ -393,8 +406,7 @@ def test_instrument_expect_with_method_removes_assertion(self): """ transformed, counter = transform_expect_calls( code=code, - func_name="fibonacci", - qualified_name="FibonacciCalculator.fibonacci", + function_to_optimize=make_func("fibonacci", class_name="FibonacciCalculator"), capture_func="capture", remove_assertions=True, ) @@ -419,7 +431,7 @@ class FibonacciCalculator { } """ transformed, counter = transform_standalone_calls( - code=code, func_name="fibonacci", qualified_name="FibonacciCalculator.fibonacci", capture_func="capture" + code=code, function_to_optimize=make_func("fibonacci", class_name="FibonacciCalculator"), capture_func="capture" ) # The method definition should NOT be transformed @@ -438,7 +450,7 @@ def test_does_not_instrument_prototype_assignment(self): }; """ transformed, counter = transform_standalone_calls( - code=code, func_name="fibonacci", qualified_name="FibonacciCalculator.fibonacci", capture_func="capture" + code=code, function_to_optimize=make_func("fibonacci", class_name="FibonacciCalculator"), capture_func="capture" ) # The prototype assignment should NOT be transformed @@ -456,7 +468,7 @@ def test_instrument_multiple_method_calls(self): const sum = a + b; """ transformed, counter = transform_standalone_calls( - code=code, func_name="fibonacci", qualified_name="Calculator.fibonacci", capture_func="capture" + code=code, function_to_optimize=make_func("fibonacci", class_name="Calculator"), capture_func="capture" ) # Should transform both calls @@ -475,7 +487,7 @@ class Wrapper { } """ transformed, counter = transform_standalone_calls( - code=code, func_name="fibonacci", qualified_name="Wrapper.fibonacci", capture_func="capture" + code=code, function_to_optimize=make_func("fibonacci", class_name="Wrapper"), capture_func="capture" ) # Should transform this.fibonacci(n) @@ -515,10 +527,9 @@ def test_full_instrumentation_produces_valid_syntax(self): """ instrumented = _instrument_js_test_code( code=test_code, - func_name="fibonacci", + function_to_optimize=make_func("fibonacci", class_name="FibonacciCalculator"), test_file_path="test.js", mode="behavior", - qualified_name="FibonacciCalculator.fibonacci", ) # Check that codeflash import was added @@ -545,7 +556,7 @@ def test_instrumentation_preserves_test_structure(self): }); """ instrumented = _instrument_js_test_code( - code=test_code, func_name="add", test_file_path="test.js", mode="behavior", qualified_name="Calculator.add" + code=test_code, function_to_optimize=make_func("add", class_name="Calculator"), test_file_path="test.js", mode="behavior" ) # describe and test structure should be preserved @@ -567,7 +578,7 @@ def test_instrumentation_with_async_methods(self): console.log(data); """ transformed, counter = transform_standalone_calls( - code=code, func_name="fetchData", qualified_name="ApiClient.fetchData", capture_func="capture" + code=code, function_to_optimize=make_func("fetchData", class_name="ApiClient"), capture_func="capture" ) # Should preserve await @@ -586,7 +597,7 @@ def test_standalone_method_call_exact_output(self): code = " calc.fibonacci(10);" transformed, counter = transform_standalone_calls( - code=code, func_name="fibonacci", qualified_name="Calculator.fibonacci", capture_func="capture" + code=code, function_to_optimize=make_func("fibonacci", class_name="Calculator"), capture_func="capture" ) expected = " codeflash.capture('Calculator.fibonacci', '1', calc.fibonacci.bind(calc), 10);" @@ -600,7 +611,7 @@ def test_expect_method_call_exact_output(self): code = " expect(calc.fibonacci(10)).toBe(55);" transformed, counter = transform_expect_calls( - code=code, func_name="fibonacci", qualified_name="Calculator.fibonacci", capture_func="capture" + code=code, function_to_optimize=make_func("fibonacci", class_name="Calculator"), capture_func="capture" ) expected = " expect(codeflash.capture('Calculator.fibonacci', '1', calc.fibonacci.bind(calc), 10)).toBe(55);" @@ -615,8 +626,7 @@ def test_expect_method_call_remove_assertions_exact_output(self): transformed, counter = transform_expect_calls( code=code, - func_name="fibonacci", - qualified_name="Calculator.fibonacci", + function_to_optimize=make_func("fibonacci", class_name="Calculator"), capture_func="capture", remove_assertions=True, ) @@ -632,7 +642,7 @@ def test_standalone_function_call_no_object_prefix(self): code = " fibonacci(10);" transformed, counter = transform_standalone_calls( - code=code, func_name="fibonacci", qualified_name="fibonacci", capture_func="capture" + code=code, function_to_optimize=make_func("fibonacci"), capture_func="capture" ) expected = " codeflash.capture('fibonacci', '1', fibonacci, 10);" @@ -646,7 +656,7 @@ def test_this_method_call_exact_output(self): code = " return this.fibonacci(n - 1);" transformed, counter = transform_standalone_calls( - code=code, func_name="fibonacci", qualified_name="Class.fibonacci", capture_func="capture" + code=code, function_to_optimize=make_func("fibonacci", class_name="Class"), capture_func="capture" ) expected = " return codeflash.capture('Class.fibonacci', '1', this.fibonacci.bind(this), n - 1);" From 1b1d5a2460aa4aeeca681d69e0512132b3ac9b69 Mon Sep 17 00:00:00 2001 From: misrasaurabh1 Date: Mon, 2 Feb 2026 09:36:24 -0800 Subject: [PATCH 2/5] refactor: consolidate FunctionInfo and FunctionToOptimize - Make FunctionToOptimize the canonical dataclass for functions across all languages - FunctionInfo is now a lazy alias for FunctionToOptimize for backward compatibility - Move Language enum to separate module to avoid circular imports - Update all test files to use new field names (function_name, starting_line, ending_line) - Add __str__ method to FunctionParent for compatibility - Add default value for parents field in FunctionToOptimize Co-Authored-By: Claude Opus 4.5 --- codeflash/cli_cmds/init_javascript.py | 2 +- codeflash/code_utils/code_extractor.py | 34 ++-- codeflash/code_utils/code_replacer.py | 16 +- codeflash/code_utils/env_utils.py | 5 +- codeflash/code_utils/formatter.py | 7 +- codeflash/context/code_context_extractor.py | 16 +- codeflash/discovery/discover_unit_tests.py | 29 +-- codeflash/discovery/functions_to_optimize.py | 37 ++-- codeflash/languages/__init__.py | 12 +- codeflash/languages/base.py | 139 ++++---------- .../languages/javascript/find_references.py | 5 +- .../languages/javascript/import_resolver.py | 17 +- .../languages/javascript/line_profiler.py | 6 +- codeflash/languages/javascript/support.py | 173 +++++++++--------- codeflash/languages/javascript/tracer.py | 6 +- codeflash/languages/language_enum.py | 17 ++ codeflash/languages/python/support.py | 70 ++++--- codeflash/languages/registry.py | 2 +- codeflash/languages/treesitter_utils.py | 1 + codeflash/models/models.py | 3 + codeflash/optimization/function_optimizer.py | 20 +- codeflash/verification/verifier.py | 8 +- tests/test_languages/test_base.py | 100 +++++----- tests/test_languages/test_javascript_e2e.py | 4 +- .../test_javascript_instrumentation.py | 8 +- .../test_languages/test_javascript_support.py | 122 ++++++------ tests/test_languages/test_language_parity.py | 114 ++++++------ tests/test_languages/test_python_support.py | 54 +++--- tests/test_languages/test_treesitter_utils.py | 24 +-- .../test_typescript_code_extraction.py | 6 +- tests/test_languages/test_vitest_e2e.py | 2 +- 31 files changed, 490 insertions(+), 569 deletions(-) create mode 100644 codeflash/languages/language_enum.py diff --git a/codeflash/cli_cmds/init_javascript.py b/codeflash/cli_cmds/init_javascript.py index c9546d1f3..05159f65e 100644 --- a/codeflash/cli_cmds/init_javascript.py +++ b/codeflash/cli_cmds/init_javascript.py @@ -16,6 +16,7 @@ from git import InvalidGitRepositoryError, Repo from rich.console import Group from rich.panel import Panel +from rich.prompt import Confirm from rich.table import Table from rich.text import Text @@ -26,7 +27,6 @@ from codeflash.code_utils.git_utils import get_git_remotes from codeflash.code_utils.shell_utils import get_shell_rc_path, is_powershell from codeflash.telemetry.posthog_cf import ph -from rich.prompt import Confirm class ProjectLanguage(Enum): diff --git a/codeflash/code_utils/code_extractor.py b/codeflash/code_utils/code_extractor.py index feb65f645..dc198b0f8 100644 --- a/codeflash/code_utils/code_extractor.py +++ b/codeflash/code_utils/code_extractor.py @@ -1577,9 +1577,11 @@ def get_opt_review_metrics( Returns: Markdown-formatted string with code blocks showing calling functions. + """ - from codeflash.languages.base import FunctionInfo, ParentInfo, ReferenceInfo + from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.languages.registry import get_language_support + from codeflash.models.models import FunctionParent start_time = time.perf_counter() @@ -1596,19 +1598,19 @@ def get_opt_review_metrics( else: function_name, class_name = qualified_name_split[1], qualified_name_split[0] - # Create a FunctionInfo for the function + # Create a FunctionToOptimize for the function # We don't have full line info here, so we'll use defaults - parents = () + parents: list[FunctionParent] = [] if class_name: - parents = (ParentInfo(name=class_name, type="ClassDef"),) + parents = [FunctionParent(name=class_name, type="ClassDef")] - func_info = FunctionInfo( - name=function_name, + func_info = FunctionToOptimize( + function_name=function_name, file_path=file_path, - start_line=1, - end_line=1, parents=parents, - language=language, + starting_line=1, + ending_line=1, + language=str(language), ) # Find references using language support @@ -1618,9 +1620,7 @@ def get_opt_review_metrics( return "" # Format references as markdown code blocks - calling_fns_details = _format_references_as_markdown( - references, file_path, project_root, language - ) + calling_fns_details = _format_references_as_markdown(references, file_path, project_root, language) except Exception as e: logger.debug(f"Error getting function references: {e}") @@ -1631,9 +1631,7 @@ def get_opt_review_metrics( return calling_fns_details -def _format_references_as_markdown( - references: list, file_path: Path, project_root: Path, language: Language -) -> str: +def _format_references_as_markdown(references: list, file_path: Path, project_root: Path, language: Language) -> str: """Format references as markdown code blocks with calling function code. Args: @@ -1644,6 +1642,7 @@ def _format_references_as_markdown( Returns: Markdown-formatted string. + """ # Group references by file refs_by_file: dict[Path, list] = {} @@ -1728,11 +1727,11 @@ def _extract_calling_function(source_code: str, function_name: str, ref_line: in Returns: Source code of the function, or None if not found. + """ if language == Language.PYTHON: return _extract_calling_function_python(source_code, function_name, ref_line) - else: - return _extract_calling_function_js(source_code, function_name, ref_line) + return _extract_calling_function_js(source_code, function_name, ref_line) def _extract_calling_function_python(source_code: str, function_name: str, ref_line: int) -> str | None: @@ -1766,6 +1765,7 @@ def _extract_calling_function_js(source_code: str, function_name: str, ref_line: Returns: Source code of the function, or None if not found. + """ try: from codeflash.languages.treesitter_utils import TreeSitterAnalyzer, TreeSitterLanguage diff --git a/codeflash/code_utils/code_replacer.py b/codeflash/code_utils/code_replacer.py index 6a57b61e1..b979dc37e 100644 --- a/codeflash/code_utils/code_replacer.py +++ b/codeflash/code_utils/code_replacer.py @@ -496,7 +496,7 @@ def replace_function_definitions_for_language( """ from codeflash.languages import get_language_support - from codeflash.languages.base import FunctionInfo, Language, ParentInfo + from codeflash.languages.base import Language original_source_code: str = module_abspath.read_text(encoding="utf8") code_to_apply = get_optimized_code_for_module(module_abspath.relative_to(project_root_path), optimized_code) @@ -523,25 +523,15 @@ def replace_function_definitions_for_language( and function_to_optimize.ending_line and function_to_optimize.file_path == module_abspath ): - parents = tuple(ParentInfo(name=p.name, type=p.type) for p in function_to_optimize.parents) - func_info = FunctionInfo( - name=function_to_optimize.function_name, - file_path=module_abspath, - start_line=function_to_optimize.starting_line, - end_line=function_to_optimize.ending_line, - parents=parents, - is_async=function_to_optimize.is_async, - language=language, - ) # Extract just the target function from the optimized code optimized_func = _extract_function_from_code( lang_support, code_to_apply, function_to_optimize.function_name, module_abspath ) if optimized_func: - new_code = lang_support.replace_function(original_source_code, func_info, optimized_func) + new_code = lang_support.replace_function(original_source_code, function_to_optimize, optimized_func) else: # Fallback: use the entire optimized code (for simple single-function files) - new_code = lang_support.replace_function(original_source_code, func_info, code_to_apply) + new_code = lang_support.replace_function(original_source_code, function_to_optimize, code_to_apply) else: # For helper files or when we don't have precise line info: # Find each function by name in both original and optimized code diff --git a/codeflash/code_utils/env_utils.py b/codeflash/code_utils/env_utils.py index 0fcc24ae6..03c7abef2 100644 --- a/codeflash/code_utils/env_utils.py +++ b/codeflash/code_utils/env_utils.py @@ -13,7 +13,6 @@ from codeflash.code_utils.code_utils import exit_with_message from codeflash.code_utils.formatter import format_code from codeflash.code_utils.shell_utils import read_api_key_from_shell_config, save_api_key_to_rc -from codeflash.languages.base import Language from codeflash.languages.registry import get_language_support_by_common_formatters from codeflash.lsp.helpers import is_LSP_enabled @@ -44,9 +43,9 @@ def check_formatter_installed( logger.debug(f"Could not determine language for formatter: {formatter_cmds}") return True - if lang_support.language == Language.PYTHON: + if str(lang_support.language) == "python": tmp_code = """print("hello world")""" - elif lang_support.language in (Language.JAVASCRIPT, Language.TYPESCRIPT): + elif str(lang_support.language) in ("javascript", "typescript"): tmp_code = "console.log('hello world');" else: return True diff --git a/codeflash/code_utils/formatter.py b/codeflash/code_utils/formatter.py index 7bf1f4cf3..4bfd96104 100644 --- a/codeflash/code_utils/formatter.py +++ b/codeflash/code_utils/formatter.py @@ -13,7 +13,6 @@ import isort from codeflash.cli_cmds.console import console, logger -from codeflash.languages.registry import get_language_support from codeflash.lsp.helpers import is_LSP_enabled @@ -43,6 +42,8 @@ def split_lines(text: str) -> list[str]: def apply_formatter_cmds( cmds: list[str], path: Path, test_dir_str: Optional[str], print_status: bool, exit_on_failure: bool = True ) -> tuple[Path, str, bool]: + from codeflash.languages.registry import get_language_support + if not path.exists(): msg = f"File {path} does not exist. Cannot apply formatter commands." raise FileNotFoundError(msg) @@ -90,6 +91,8 @@ def is_diff_line(line: str) -> bool: def format_generated_code(generated_test_source: str, formatter_cmds: list[str], language: str = "python") -> str: + from codeflash.languages.registry import get_language_support + formatter_name = formatter_cmds[0].lower() if formatter_cmds else "disabled" if formatter_name == "disabled": # nothing to do if no formatter provided return re.sub(r"\n{2,}", "\n\n", generated_test_source) @@ -114,6 +117,8 @@ def format_code( print_status: bool = True, exit_on_failure: bool = True, ) -> str: + from codeflash.languages.registry import get_language_support + if is_LSP_enabled(): exit_on_failure = False diff --git a/codeflash/context/code_context_extractor.py b/codeflash/context/code_context_extractor.py index 0479d4178..466d2aa46 100644 --- a/codeflash/context/code_context_extractor.py +++ b/codeflash/context/code_context_extractor.py @@ -234,27 +234,13 @@ def get_code_optimization_context_for_language( """ from codeflash.languages import get_language_support - from codeflash.languages.base import FunctionInfo, ParentInfo # Get language support for this function language = Language(function_to_optimize.language) lang_support = get_language_support(language) - # Convert FunctionToOptimize to FunctionInfo for language support - parents = tuple(ParentInfo(name=p.name, type=p.type) for p in function_to_optimize.parents) - func_info = FunctionInfo( - name=function_to_optimize.function_name, - file_path=function_to_optimize.file_path, - start_line=function_to_optimize.starting_line or 1, - end_line=function_to_optimize.ending_line or 1, - parents=parents, - is_async=function_to_optimize.is_async, - is_method=len(function_to_optimize.parents) > 0, - language=language, - ) - # Extract code context using language support - code_context = lang_support.extract_code_context(func_info, project_root_path, project_root_path) + code_context = lang_support.extract_code_context(function_to_optimize, project_root_path, project_root_path) # Build imports string if available imports_code = "\n".join(code_context.imports) if code_context.imports else "" diff --git a/codeflash/discovery/discover_unit_tests.py b/codeflash/discovery/discover_unit_tests.py index 3cde1d6d2..cd0a82605 100644 --- a/codeflash/discovery/discover_unit_tests.py +++ b/codeflash/discovery/discover_unit_tests.py @@ -29,7 +29,6 @@ ) from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE, codeflash_cache_db from codeflash.code_utils.shell_utils import get_cross_platform_subprocess_run_args -from codeflash.languages import is_javascript, is_python from codeflash.models.models import CodePosition, FunctionCalledInTest, TestsInFile, TestType if TYPE_CHECKING: @@ -589,7 +588,7 @@ def discover_tests_for_language( """ from codeflash.languages import get_language_support - from codeflash.languages.base import FunctionInfo, Language, ParentInfo + from codeflash.languages.base import Language try: lang_support = get_language_support(Language(language)) @@ -597,34 +596,20 @@ def discover_tests_for_language( logger.warning(f"Unsupported language {language}, returning empty test map") return {}, 0, 0 - # Convert FunctionToOptimize to FunctionInfo for the language support API - # Also build a mapping from simple qualified_name to full qualified_name_with_modules - function_infos: list[FunctionInfo] = [] + # Collect all functions and build a mapping from simple qualified_name to full qualified_name_with_modules + all_functions: list[FunctionToOptimize] = [] simple_to_full_name: dict[str, str] = {} if file_to_funcs_to_optimize: for funcs in file_to_funcs_to_optimize.values(): for func in funcs: - parents = tuple(ParentInfo(p.name, p.type) for p in func.parents) - func_info = FunctionInfo( - name=func.function_name, - file_path=func.file_path, - start_line=func.starting_line or 0, - end_line=func.ending_line or 0, - start_col=func.starting_col, - end_col=func.ending_col, - is_async=func.is_async, - is_method=bool(func.parents and any(p.type == "ClassDef" for p in func.parents)), - parents=parents, - language=Language(language), - ) - function_infos.append(func_info) + all_functions.append(func) # Map simple qualified_name to full qualified_name_with_modules_from_root - simple_to_full_name[func_info.qualified_name] = func.qualified_name_with_modules_from_root( + simple_to_full_name[func.qualified_name] = func.qualified_name_with_modules_from_root( cfg.project_root_path ) # Use language support to discover tests - test_map = lang_support.discover_tests(cfg.tests_root, function_infos) + test_map = lang_support.discover_tests(cfg.tests_root, all_functions) # Convert TestInfo back to FunctionCalledInTest format # Use the full qualified name (with modules) as the key for consistency with Python @@ -656,6 +641,8 @@ def discover_unit_tests( discover_only_these_tests: list[Path] | None = None, file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]] | None = None, ) -> tuple[dict[str, set[FunctionCalledInTest]], int, int]: + from codeflash.languages import is_javascript, is_python + # Detect language from functions being optimized language = _detect_language_from_functions(file_to_funcs_to_optimize) diff --git a/codeflash/discovery/functions_to_optimize.py b/codeflash/discovery/functions_to_optimize.py index 830380bde..6db677ca7 100644 --- a/codeflash/discovery/functions_to_optimize.py +++ b/codeflash/discovery/functions_to_optimize.py @@ -12,6 +12,7 @@ import git import libcst as cst +from pydantic import Field from pydantic.dataclasses import dataclass from rich.tree import Tree @@ -26,9 +27,8 @@ from codeflash.code_utils.env_utils import get_pr_number from codeflash.code_utils.git_utils import get_git_diff, get_repo_owner_and_name from codeflash.discovery.discover_unit_tests import discover_unit_tests -from codeflash.languages import get_language_support, get_supported_extensions -from codeflash.languages.base import Language -from codeflash.languages.registry import is_language_supported +from codeflash.languages.language_enum import Language +from codeflash.languages.registry import get_language_support, get_supported_extensions, is_language_supported from codeflash.lsp.helpers import is_LSP_enabled from codeflash.models.models import FunctionParent from codeflash.telemetry.posthog_cf import ph @@ -134,17 +134,23 @@ def generic_visit(self, node: ast.AST) -> None: class FunctionToOptimize: """Represent a function that is a candidate for optimization. + This is the canonical dataclass for representing functions across all languages + (Python, JavaScript, TypeScript). It captures all information needed to identify, + locate, and work with a function. + Attributes ---------- function_name: The name of the function. file_path: The absolute file path where the function is located. parents: A list of parent scopes, which could be classes or functions. - starting_line: The starting line number of the function in the file. - ending_line: The ending line number of the function in the file. - starting_col: The starting column offset (for precise location in multi-line contexts). - ending_col: The ending column offset (for precise location in multi-line contexts). + starting_line: The starting line number of the function in the file (1-indexed). + ending_line: The ending line number of the function in the file (1-indexed). + starting_col: The starting column offset (0-indexed, for precise location). + ending_col: The ending column offset (0-indexed, for precise location). is_async: Whether this function is defined as async. + is_method: Whether this is a method (belongs to a class). language: The programming language of this function (default: "python"). + doc_start_line: Line where docstring/JSDoc starts (or None if no doc comment). The qualified_name property provides the full name of the function, including any parent class or function names. The qualified_name_with_modules_from_root @@ -154,13 +160,15 @@ class FunctionToOptimize: function_name: str file_path: Path - parents: list[FunctionParent] # list[ClassDef | FunctionDef | AsyncFunctionDef] + parents: list[FunctionParent] = Field(default_factory=list) # list[ClassDef | FunctionDef | AsyncFunctionDef] starting_line: Optional[int] = None ending_line: Optional[int] = None starting_col: Optional[int] = None # Column offset for precise location ending_col: Optional[int] = None # Column offset for precise location is_async: bool = False + is_method: bool = False # Whether this is a method (belongs to a class) language: str = "python" # Language identifier for multi-language support + doc_start_line: Optional[int] = None # Line where docstring/JSDoc starts @property def top_level_parent_name(self) -> str: @@ -175,10 +183,9 @@ def class_name(self) -> str | None: return None def __str__(self) -> str: - return ( - f"{self.file_path}:{'.'.join([p.name for p in self.parents])}" - f"{'.' if self.parents else ''}{self.function_name}" - ) + qualified = f"{'.'.join([p.name for p in self.parents])}{'.' if self.parents else ''}{self.function_name}" + line_info = f":{self.starting_line}-{self.ending_line}" if self.starting_line and self.ending_line else "" + return f"{self.file_path}:{qualified}{line_info}" @property def qualified_name(self) -> str: @@ -195,8 +202,8 @@ def qualified_name_with_modules_from_root(self, project_root_path: Path) -> str: def from_function_info(cls, func_info: FunctionInfo) -> FunctionToOptimize: """Create a FunctionToOptimize from a FunctionInfo instance. - This enables interoperability between the language-agnostic FunctionInfo - and the FunctionToOptimize dataclass used throughout the codebase. + This is a temporary method for backward compatibility during migration. + Once FunctionInfo is fully removed, this method can be deleted. """ parents = [FunctionParent(name=p.name, type=p.type) for p in func_info.parents] return cls( @@ -208,7 +215,9 @@ def from_function_info(cls, func_info: FunctionInfo) -> FunctionToOptimize: starting_col=func_info.start_col, ending_col=func_info.end_col, is_async=func_info.is_async, + is_method=func_info.is_method, language=func_info.language.value, + doc_start_line=func_info.doc_start_line, ) diff --git a/codeflash/languages/__init__.py b/codeflash/languages/__init__.py index e99502c1d..0c4e9d07c 100644 --- a/codeflash/languages/__init__.py +++ b/codeflash/languages/__init__.py @@ -19,7 +19,6 @@ from codeflash.languages.base import ( CodeContext, - FunctionInfo, HelperFunction, Language, LanguageSupport, @@ -27,6 +26,17 @@ TestInfo, TestResult, ) + + +# Lazy import for FunctionInfo to avoid circular imports +def __getattr__(name: str): + if name == "FunctionInfo": + from codeflash.discovery.functions_to_optimize import FunctionToOptimize + + return FunctionToOptimize + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + from codeflash.languages.current import ( current_language, current_language_support, diff --git a/codeflash/languages/base.py b/codeflash/languages/base.py index 6e3ea4417..5fb7f99ce 100644 --- a/codeflash/languages/base.py +++ b/codeflash/languages/base.py @@ -2,111 +2,35 @@ This module defines the core abstractions that all language implementations must follow. The LanguageSupport protocol defines the interface that each language must implement, -while the dataclasses define language-agnostic representations of code constructs. +while FunctionToOptimize is the canonical representation of functions across all languages. """ from __future__ import annotations from dataclasses import dataclass, field -from enum import Enum from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable if TYPE_CHECKING: from collections.abc import Sequence from pathlib import Path + from codeflash.discovery.functions_to_optimize import FunctionToOptimize -class Language(str, Enum): - """Supported programming languages.""" +from codeflash.languages.language_enum import Language +from codeflash.models.models import FunctionParent - PYTHON = "python" - JAVASCRIPT = "javascript" - TYPESCRIPT = "typescript" +# Backward compatibility aliases - ParentInfo is now FunctionParent +ParentInfo = FunctionParent - def __str__(self) -> str: - return self.value +# Lazy import for FunctionInfo to avoid circular imports +# This allows `from codeflash.languages.base import FunctionInfo` to work at runtime +def __getattr__(name: str) -> Any: + if name == "FunctionInfo": + from codeflash.discovery.functions_to_optimize import FunctionToOptimize -@dataclass(frozen=True) -class ParentInfo: - """Parent scope information for nested functions/methods. - - Represents the parent class or function that contains a nested function. - Used to construct the qualified name of a function. - - Attributes: - name: The name of the parent scope (class name or function name). - type: The type of parent ("ClassDef", "FunctionDef", "AsyncFunctionDef", etc.). - - """ - - name: str - type: str # "ClassDef", "FunctionDef", "AsyncFunctionDef", etc. - - def __str__(self) -> str: - return f"{self.type}:{self.name}" - - -@dataclass(frozen=True) -class FunctionInfo: - """Language-agnostic representation of a function to optimize. - - This class captures all the information needed to identify, locate, and - work with a function across different programming languages. - - Attributes: - name: The simple function name (e.g., "add"). - file_path: Absolute path to the file containing the function. - start_line: Starting line number (1-indexed). - end_line: Ending line number (1-indexed, inclusive). - parents: List of parent scopes (for nested functions/methods). - is_async: Whether this is an async function. - is_method: Whether this is a method (belongs to a class). - language: The programming language. - start_col: Starting column (0-indexed), optional for more precise location. - end_col: Ending column (0-indexed), optional. - - """ - - name: str - file_path: Path - start_line: int - end_line: int - parents: tuple[ParentInfo, ...] = () - is_async: bool = False - is_method: bool = False - language: Language = Language.PYTHON - start_col: int | None = None - end_col: int | None = None - doc_start_line: int | None = None # Line where docstring/JSDoc starts (or None if no doc comment) - - @property - def qualified_name(self) -> str: - """Full qualified name including parent scopes. - - For a method `add` in class `Calculator`, returns "Calculator.add". - For nested functions, includes all parent scopes. - """ - if not self.parents: - return self.name - parent_path = ".".join(parent.name for parent in self.parents) - return f"{parent_path}.{self.name}" - - @property - def class_name(self) -> str | None: - """Get the immediate parent class name, if any.""" - for parent in reversed(self.parents): - if parent.type == "ClassDef": - return parent.name - return None - - @property - def top_level_parent_name(self) -> str: - """Get the top-level parent name, or function name if no parents.""" - return self.parents[0].name if self.parents else self.name - - def __str__(self) -> str: - return f"FunctionInfo({self.qualified_name} at {self.file_path}:{self.start_line}-{self.end_line})" + return FunctionToOptimize + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") @dataclass @@ -333,7 +257,7 @@ def comment_prefix(self) -> str: def discover_functions( self, file_path: Path, filter_criteria: FunctionFilterCriteria | None = None - ) -> list[FunctionInfo]: + ) -> list[FunctionToOptimize]: """Find all optimizable functions in a file. Args: @@ -341,12 +265,14 @@ def discover_functions( filter_criteria: Optional criteria to filter functions. Returns: - List of FunctionInfo objects for discovered functions. + List of FunctionToOptimize objects for discovered functions. """ ... - def discover_tests(self, test_root: Path, source_functions: Sequence[FunctionInfo]) -> dict[str, list[TestInfo]]: + def discover_tests( + self, test_root: Path, source_functions: Sequence[FunctionToOptimize] + ) -> dict[str, list[TestInfo]]: """Map source functions to their tests via static analysis. Args: @@ -361,7 +287,7 @@ def discover_tests(self, test_root: Path, source_functions: Sequence[FunctionInf # === Code Analysis === - def extract_code_context(self, function: FunctionInfo, project_root: Path, module_root: Path) -> CodeContext: + def extract_code_context(self, function: FunctionToOptimize, project_root: Path, module_root: Path) -> CodeContext: """Extract function code and its dependencies. Args: @@ -375,7 +301,7 @@ def extract_code_context(self, function: FunctionInfo, project_root: Path, modul """ ... - def find_helper_functions(self, function: FunctionInfo, project_root: Path) -> list[HelperFunction]: + def find_helper_functions(self, function: FunctionToOptimize, project_root: Path) -> list[HelperFunction]: """Find helper functions called by the target function. Args: @@ -389,7 +315,7 @@ def find_helper_functions(self, function: FunctionInfo, project_root: Path) -> l ... def find_references( - self, function: FunctionInfo, project_root: Path, tests_root: Path | None = None, max_files: int = 500 + self, function: FunctionToOptimize, project_root: Path, tests_root: Path | None = None, max_files: int = 500 ) -> list[ReferenceInfo]: """Find all references (call sites) to a function across the codebase. @@ -413,12 +339,12 @@ def find_references( # === Code Transformation === - def replace_function(self, source: str, function: FunctionInfo, new_source: str) -> str: + def replace_function(self, source: str, function: FunctionToOptimize, new_source: str) -> str: """Replace a function in source code with new implementation. Args: source: Original source code. - function: FunctionInfo identifying the function to replace. + function: FunctionToOptimize identifying the function to replace. new_source: New function source code. Returns: @@ -474,7 +400,7 @@ def parse_test_results(self, junit_xml_path: Path, stdout: str) -> list[TestResu # === Instrumentation === - def instrument_for_behavior(self, source: str, functions: Sequence[FunctionInfo]) -> str: + def instrument_for_behavior(self, source: str, functions: Sequence[FunctionToOptimize]) -> str: """Add behavior instrumentation to capture inputs/outputs. Args: @@ -487,7 +413,7 @@ def instrument_for_behavior(self, source: str, functions: Sequence[FunctionInfo] """ ... - def instrument_for_benchmarking(self, test_source: str, target_function: FunctionInfo) -> str: + def instrument_for_benchmarking(self, test_source: str, target_function: FunctionToOptimize) -> str: """Add timing instrumentation to test code. Args: @@ -664,7 +590,9 @@ def instrument_existing_test( """ ... - def instrument_source_for_line_profiler(self, func_info: FunctionInfo, line_profiler_output_file: Path) -> bool: + def instrument_source_for_line_profiler( + self, func_info: FunctionToOptimize, line_profiler_output_file: Path + ) -> bool: """Instrument source code before line profiling.""" ... @@ -731,17 +659,14 @@ def run_benchmarking_tests( ... -def convert_parents_to_tuple(parents: list | tuple) -> tuple[ParentInfo, ...]: - """Convert a list of parent objects to a tuple of ParentInfo. - - This helper handles conversion from the existing FunctionParent - dataclass to the new ParentInfo dataclass. +def convert_parents_to_tuple(parents: list | tuple) -> tuple[FunctionParent, ...]: + """Convert a list of parent objects to a tuple of FunctionParent. Args: parents: List or tuple of parent objects with name and type attributes. Returns: - Tuple of ParentInfo objects. + Tuple of FunctionParent objects. """ - return tuple(ParentInfo(name=p.name, type=p.type) for p in parents) + return tuple(FunctionParent(name=p.name, type=p.type) for p in parents) diff --git a/codeflash/languages/javascript/find_references.py b/codeflash/languages/javascript/find_references.py index 18e096a74..3cf0d8b30 100644 --- a/codeflash/languages/javascript/find_references.py +++ b/codeflash/languages/javascript/find_references.py @@ -73,10 +73,7 @@ class ReferenceFinder: from codeflash.discovery.functions_to_optimize import FunctionToOptimize func = FunctionToOptimize( - function_name="myHelper", - file_path=Path("/my/project/src/utils.ts"), - parents=[], - language="javascript" + function_name="myHelper", file_path=Path("/my/project/src/utils.ts"), parents=[], language="javascript" ) finder = ReferenceFinder(project_root=Path("/my/project")) references = finder.find_references(func) diff --git a/codeflash/languages/javascript/import_resolver.py b/codeflash/languages/javascript/import_resolver.py index 49452ec51..f4def166f 100644 --- a/codeflash/languages/javascript/import_resolver.py +++ b/codeflash/languages/javascript/import_resolver.py @@ -12,7 +12,8 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from codeflash.languages.base import FunctionInfo, HelperFunction + from codeflash.discovery.functions_to_optimize import FunctionToOptimize + from codeflash.languages.base import HelperFunction from codeflash.languages.treesitter_utils import ImportInfo, TreeSitterAnalyzer logger = logging.getLogger(__name__) @@ -302,7 +303,7 @@ def __init__(self, project_root: Path, import_resolver: ImportResolver) -> None: def find_helpers( self, - function: FunctionInfo, + function: FunctionToOptimize, source: str, analyzer: TreeSitterAnalyzer, imports: list[ImportInfo], @@ -505,7 +506,7 @@ def _find_helpers_recursive( Dictionary mapping file paths to lists of helper functions. """ - from codeflash.languages.base import FunctionInfo + from codeflash.languages.base import FunctionToOptimize from codeflash.languages.treesitter_utils import get_analyzer_for_file if context.current_depth >= context.max_depth: @@ -525,9 +526,13 @@ def _find_helpers_recursive( analyzer = get_analyzer_for_file(file_path) imports = analyzer.find_imports(source) - # Create FunctionInfo for the helper - func_info = FunctionInfo( - name=helper.name, file_path=file_path, start_line=helper.start_line, end_line=helper.end_line, parents=() + # Create FunctionToOptimize for the helper + func_info = FunctionToOptimize( + function_name=helper.name, + file_path=file_path, + parents=[], + starting_line=helper.start_line, + ending_line=helper.end_line, ) # Recursively find helpers diff --git a/codeflash/languages/javascript/line_profiler.py b/codeflash/languages/javascript/line_profiler.py index 757dd9282..c2ea2f495 100644 --- a/codeflash/languages/javascript/line_profiler.py +++ b/codeflash/languages/javascript/line_profiler.py @@ -16,7 +16,7 @@ if TYPE_CHECKING: from pathlib import Path - from codeflash.languages.base import FunctionInfo + from codeflash.discovery.functions_to_optimize import FunctionToOptimize logger = logging.getLogger(__name__) @@ -40,7 +40,7 @@ def __init__(self, output_file: Path) -> None: self.output_file = output_file self.profiler_var = "__codeflash_line_profiler__" - def instrument_source(self, source: str, file_path: Path, functions: list[FunctionInfo]) -> str: + def instrument_source(self, source: str, file_path: Path, functions: list[FunctionToOptimize]) -> str: """Instrument JavaScript source code with line profiling. Adds profiling instrumentation to track line-level execution for the @@ -171,7 +171,7 @@ def _generate_profiler_save(self) -> str: if (__codeflash_save_interval__.unref) __codeflash_save_interval__.unref(); // Don't keep process alive """ - def _instrument_function(self, func: FunctionInfo, lines: list[str], file_path: Path) -> list[str]: + def _instrument_function(self, func: FunctionToOptimize, lines: list[str], file_path: Path) -> list[str]: """Instrument a single function with line profiling. Args: diff --git a/codeflash/languages/javascript/support.py b/codeflash/languages/javascript/support.py index 8cdd54e47..d9ca2cfea 100644 --- a/codeflash/languages/javascript/support.py +++ b/codeflash/languages/javascript/support.py @@ -12,18 +12,11 @@ from pathlib import Path from typing import TYPE_CHECKING, Any -from codeflash.languages.base import ( - CodeContext, - FunctionFilterCriteria, - FunctionInfo, - HelperFunction, - Language, - ParentInfo, - TestInfo, - TestResult, -) +from codeflash.discovery.functions_to_optimize import FunctionToOptimize +from codeflash.languages.base import CodeContext, FunctionFilterCriteria, HelperFunction, Language, TestInfo, TestResult from codeflash.languages.registry import register_language from codeflash.languages.treesitter_utils import TreeSitterAnalyzer, TreeSitterLanguage, get_analyzer_for_file +from codeflash.models.models import FunctionParent if TYPE_CHECKING: from collections.abc import Sequence @@ -72,7 +65,7 @@ def comment_prefix(self) -> str: def discover_functions( self, file_path: Path, filter_criteria: FunctionFilterCriteria | None = None - ) -> list[FunctionInfo]: + ) -> list[FunctionToOptimize]: """Find all optimizable functions in a JavaScript file. Uses tree-sitter to parse the file and find functions. @@ -82,7 +75,7 @@ def discover_functions( filter_criteria: Optional criteria to filter functions. Returns: - List of FunctionInfo objects for discovered functions. + List of FunctionToOptimize objects for discovered functions. """ criteria = filter_criteria or FunctionFilterCriteria() @@ -99,7 +92,7 @@ def discover_functions( source, include_methods=criteria.include_methods, include_arrow_functions=True, require_name=True ) - functions: list[FunctionInfo] = [] + functions: list[FunctionToOptimize] = [] for func in tree_functions: # Check for return statement if required if criteria.require_return and not analyzer.has_return_statement(func, source): @@ -110,24 +103,24 @@ def discover_functions( continue # Build parents list - parents: list[ParentInfo] = [] + parents: list[FunctionParent] = [] if func.class_name: - parents.append(ParentInfo(name=func.class_name, type="ClassDef")) + parents.append(FunctionParent(name=func.class_name, type="ClassDef")) if func.parent_function: - parents.append(ParentInfo(name=func.parent_function, type="FunctionDef")) + parents.append(FunctionParent(name=func.parent_function, type="FunctionDef")) functions.append( - FunctionInfo( - name=func.name, + FunctionToOptimize( + function_name=func.name, file_path=file_path, - start_line=func.start_line, - end_line=func.end_line, - start_col=func.start_col, - end_col=func.end_col, - parents=tuple(parents), + parents=parents, + starting_line=func.start_line, + ending_line=func.end_line, + starting_col=func.start_col, + ending_col=func.end_col, is_async=func.is_async, is_method=func.is_method, - language=self.language, + language=str(self.language), doc_start_line=func.doc_start_line, ) ) @@ -138,7 +131,7 @@ def discover_functions( logger.warning("Failed to parse %s: %s", file_path, e) return [] - def discover_functions_from_source(self, source: str, file_path: Path | None = None) -> list[FunctionInfo]: + def discover_functions_from_source(self, source: str, file_path: Path | None = None) -> list[FunctionToOptimize]: """Find all functions in source code string. Uses tree-sitter to parse the source and find functions. @@ -148,7 +141,7 @@ def discover_functions_from_source(self, source: str, file_path: Path | None = N file_path: Optional file path for context (used for language detection). Returns: - List of FunctionInfo objects for discovered functions. + List of FunctionToOptimize objects for discovered functions. """ try: @@ -162,27 +155,27 @@ def discover_functions_from_source(self, source: str, file_path: Path | None = N source, include_methods=True, include_arrow_functions=True, require_name=True ) - functions: list[FunctionInfo] = [] + functions: list[FunctionToOptimize] = [] for func in tree_functions: # Build parents list - parents: list[ParentInfo] = [] + parents: list[FunctionParent] = [] if func.class_name: - parents.append(ParentInfo(name=func.class_name, type="ClassDef")) + parents.append(FunctionParent(name=func.class_name, type="ClassDef")) if func.parent_function: - parents.append(ParentInfo(name=func.parent_function, type="FunctionDef")) + parents.append(FunctionParent(name=func.parent_function, type="FunctionDef")) functions.append( - FunctionInfo( - name=func.name, + FunctionToOptimize( + function_name=func.name, file_path=file_path or Path("unknown"), - start_line=func.start_line, - end_line=func.end_line, - start_col=func.start_col, - end_col=func.end_col, - parents=tuple(parents), + parents=parents, + starting_line=func.start_line, + ending_line=func.end_line, + starting_col=func.start_col, + ending_col=func.end_col, is_async=func.is_async, is_method=func.is_method, - language=self.language, + language=str(self.language), doc_start_line=func.doc_start_line, ) ) @@ -204,7 +197,9 @@ def _get_test_patterns(self) -> list[str]: """ return ["*.test.js", "*.test.jsx", "*.spec.js", "*.spec.jsx", "__tests__/**/*.js", "__tests__/**/*.jsx"] - def discover_tests(self, test_root: Path, source_functions: Sequence[FunctionInfo]) -> dict[str, list[TestInfo]]: + def discover_tests( + self, test_root: Path, source_functions: Sequence[FunctionToOptimize] + ) -> dict[str, list[TestInfo]]: """Map source functions to their tests via static analysis. For JavaScript, this uses static analysis to find test files @@ -288,7 +283,7 @@ def _walk_for_jest_tests(self, node: Any, source_bytes: bytes, test_names: list[ # === Code Analysis === - def extract_code_context(self, function: FunctionInfo, project_root: Path, module_root: Path) -> CodeContext: + def extract_code_context(self, function: FunctionToOptimize, project_root: Path, module_root: Path) -> CodeContext: """Extract function code and its dependencies. Uses tree-sitter to analyze imports and find helper functions. @@ -315,16 +310,16 @@ def extract_code_context(self, function: FunctionInfo, project_root: Path, modul tree_functions = analyzer.find_functions(source, include_methods=True, include_arrow_functions=True) target_func = None for func in tree_functions: - if func.name == function.name and func.start_line == function.start_line: + if func.name == function.function_name and func.start_line == function.starting_line: target_func = func break # Extract the function source, including JSDoc if present lines = source.splitlines(keepends=True) - if function.start_line and function.end_line: + if function.starting_line and function.ending_line: # Use doc_start_line if available, otherwise fall back to start_line - effective_start = (target_func.doc_start_line if target_func else None) or function.start_line - target_lines = lines[effective_start - 1 : function.end_line] + effective_start = (target_func.doc_start_line if target_func else None) or function.starting_line + target_lines = lines[effective_start - 1 : function.ending_line] target_code = "".join(target_lines) else: target_code = "" @@ -340,7 +335,7 @@ def extract_code_context(self, function: FunctionInfo, project_root: Path, modul if class_name: # Find the class definition in the source to get proper indentation, JSDoc, constructor, and fields - class_info = self._find_class_definition(source, class_name, analyzer, function.name) + class_info = self._find_class_definition(source, class_name, analyzer, function.function_name) if class_info: class_jsdoc, class_indent, constructor_code, fields_code = class_info # Build the class body with fields, constructor, and target method @@ -401,7 +396,7 @@ def extract_code_context(self, function: FunctionInfo, project_root: Path, modul # If not, raise an error to fail the optimization early if target_code and not self.validate_syntax(target_code): error_msg = ( - f"Extracted code for {function.name} is not syntactically valid JavaScript. " + f"Extracted code for {function.function_name} is not syntactically valid JavaScript. " f"Cannot proceed with optimization." ) logger.error(error_msg) @@ -550,7 +545,12 @@ def _extract_class_context( return (constructor_code, fields_code) def _find_helper_functions( - self, function: FunctionInfo, source: str, analyzer: TreeSitterAnalyzer, imports: list[Any], module_root: Path + self, + function: FunctionToOptimize, + source: str, + analyzer: TreeSitterAnalyzer, + imports: list[Any], + module_root: Path, ) -> list[HelperFunction]: """Find helper functions called by the target function. @@ -575,7 +575,7 @@ def _find_helper_functions( # Find the target function's tree-sitter node target_func = None for func in all_functions: - if func.name == function.name and func.start_line == function.start_line: + if func.name == function.function_name and func.start_line == function.starting_line: target_func = func break @@ -591,7 +591,7 @@ def _find_helper_functions( # Match calls to functions in the same file for func in all_functions: - if func.name in calls_set and func.name != function.name: + if func.name in calls_set and func.name != function.function_name: # Extract source including JSDoc if present effective_start = func.doc_start_line or func.start_line helper_lines = lines[effective_start - 1 : func.end_line] @@ -721,7 +721,12 @@ def _find_referenced_globals( return "\n".join(global_lines) def _extract_type_definitions_context( - self, function: FunctionInfo, source: str, analyzer: TreeSitterAnalyzer, imports: list[Any], module_root: Path + self, + function: FunctionToOptimize, + source: str, + analyzer: TreeSitterAnalyzer, + imports: list[Any], + module_root: Path, ) -> tuple[str, set[str]]: """Extract type definitions used by the function for read-only context. @@ -747,7 +752,7 @@ def _extract_type_definitions_context( """ # Extract type names from function parameters and return type - type_names = analyzer.extract_type_annotations(source, function.name, function.start_line or 1) + type_names = analyzer.extract_type_annotations(source, function.function_name, function.starting_line or 1) # If this is a class method, also extract types from class fields if function.is_method and function.parents: @@ -945,7 +950,7 @@ def _find_imported_type_definitions( return found_definitions - def find_helper_functions(self, function: FunctionInfo, project_root: Path) -> list[HelperFunction]: + def find_helper_functions(self, function: FunctionToOptimize, project_root: Path) -> list[HelperFunction]: """Find helper functions called by the target function. Args: @@ -962,11 +967,11 @@ def find_helper_functions(self, function: FunctionInfo, project_root: Path) -> l imports = analyzer.find_imports(source) return self._find_helper_functions(function, source, analyzer, imports, project_root) except Exception as e: - logger.warning("Failed to find helpers for %s: %s", function.name, e) + logger.warning("Failed to find helpers for %s: %s", function.function_name, e) return [] def find_references( - self, function: FunctionInfo, project_root: Path, tests_root: Path | None = None, max_files: int = 500 + self, function: FunctionToOptimize, project_root: Path, tests_root: Path | None = None, max_files: int = 500 ) -> list[ReferenceInfo]: """Find all references (call sites) to a function across the codebase. @@ -983,14 +988,12 @@ def find_references( List of ReferenceInfo objects describing each reference location. """ - from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.languages.base import ReferenceInfo from codeflash.languages.javascript.find_references import ReferenceFinder try: finder = ReferenceFinder(project_root) - func_to_optimize = FunctionToOptimize.from_function_info(function) - refs = finder.find_references(func_to_optimize, max_files=max_files) + refs = finder.find_references(function, max_files=max_files) # Convert to ReferenceInfo and filter out tests result: list[ReferenceInfo] = [] @@ -1020,12 +1023,12 @@ def find_references( return result except Exception as e: - logger.warning("Failed to find references for %s: %s", function.name, e) + logger.warning("Failed to find references for %s: %s", function.function_name, e) return [] # === Code Transformation === - def replace_function(self, source: str, function: FunctionInfo, new_source: str) -> str: + def replace_function(self, source: str, function: FunctionToOptimize, new_source: str) -> str: """Replace a function in source code with new implementation. Uses node-based replacement to extract the method body from the optimized code @@ -1037,7 +1040,7 @@ def replace_function(self, source: str, function: FunctionInfo, new_source: str) Args: source: Original source code. - function: FunctionInfo identifying the function to replace. + function: FunctionToOptimize identifying the function to replace. new_source: New source code containing the optimized function. Returns: @@ -1045,13 +1048,13 @@ def replace_function(self, source: str, function: FunctionInfo, new_source: str) if new_source is empty or invalid. """ - if function.start_line is None or function.end_line is None: - logger.error("Function %s has no line information", function.name) + if function.starting_line is None or function.ending_line is None: + logger.error("Function %s has no line information", function.function_name) return source # If new_source is empty or whitespace-only, return original unchanged if not new_source or not new_source.strip(): - logger.warning("Empty new_source provided for %s, returning original", function.name) + logger.warning("Empty new_source provided for %s, returning original", function.function_name) return source # Get analyzer for parsing @@ -1065,19 +1068,21 @@ def replace_function(self, source: str, function: FunctionInfo, new_source: str) stripped_new_source = new_source.strip() if stripped_new_source.startswith("/**"): # new_source includes JSDoc, use full replacement to apply the new JSDoc - if not self._contains_function_declaration(new_source, function.name, analyzer): - logger.warning("new_source does not contain function %s, returning original", function.name) + if not self._contains_function_declaration(new_source, function.function_name, analyzer): + logger.warning("new_source does not contain function %s, returning original", function.function_name) return source return self._replace_function_text_based(source, function, new_source, analyzer) # Extract just the method body from the new source - new_body = self._extract_function_body(new_source, function.name, analyzer) + new_body = self._extract_function_body(new_source, function.function_name, analyzer) if new_body is None: - logger.warning("Could not extract body for %s from optimized code, using full replacement", function.name) + logger.warning( + "Could not extract body for %s from optimized code, using full replacement", function.function_name + ) # Verify that new_source contains actual code before falling back to text replacement # This prevents deletion of the original function when new_source is invalid - if not self._contains_function_declaration(new_source, function.name, analyzer): - logger.warning("new_source does not contain function %s, returning original", function.name) + if not self._contains_function_declaration(new_source, function.function_name, analyzer): + logger.warning("new_source does not contain function %s, returning original", function.function_name) return source return self._replace_function_text_based(source, function, new_source, analyzer) @@ -1205,7 +1210,7 @@ def find_function_node(node, target_name: str): return source_bytes[body_node.start_byte : body_node.end_byte].decode("utf8") def _replace_function_body( - self, source: str, function: FunctionInfo, new_body: str, analyzer: TreeSitterAnalyzer + self, source: str, function: FunctionToOptimize, new_body: str, analyzer: TreeSitterAnalyzer ) -> str: """Replace the body of a function in source code with new body content. @@ -1213,7 +1218,7 @@ def _replace_function_body( Args: source: Original source code. - function: FunctionInfo identifying the function to modify. + function: FunctionToOptimize identifying the function to modify. new_body: New body content (including braces). analyzer: TreeSitterAnalyzer for parsing. @@ -1264,9 +1269,9 @@ def find_function_at_line(node, target_name: str, target_line: int): return None - func_node = find_function_at_line(tree.root_node, function.name, function.start_line) + func_node = find_function_at_line(tree.root_node, function.function_name, function.starting_line) if not func_node: - logger.warning("Could not find function %s at line %s", function.name, function.start_line) + logger.warning("Could not find function %s at line %s", function.function_name, function.starting_line) return source # Find the body node in the original @@ -1278,7 +1283,7 @@ def find_function_at_line(node, target_name: str, target_line: int): break if not body_node: - logger.warning("Could not find body for function %s", function.name) + logger.warning("Could not find body for function %s", function.function_name) return source # Get the indentation of the original body's opening brace @@ -1346,7 +1351,7 @@ def find_function_at_line(node, target_name: str, target_line: int): return result.decode("utf8") def _replace_function_text_based( - self, source: str, function: FunctionInfo, new_source: str, analyzer: TreeSitterAnalyzer + self, source: str, function: FunctionToOptimize, new_source: str, analyzer: TreeSitterAnalyzer ) -> str: """Fallback text-based replacement when node-based replacement fails. @@ -1354,7 +1359,7 @@ def _replace_function_text_based( Args: source: Original source code. - function: FunctionInfo identifying the function to replace. + function: FunctionToOptimize identifying the function to replace. new_source: New function source code. analyzer: TreeSitterAnalyzer for parsing. @@ -1371,16 +1376,16 @@ def _replace_function_text_based( tree_functions = analyzer.find_functions(source, include_methods=True, include_arrow_functions=True) target_func = None for func in tree_functions: - if func.name == function.name and func.start_line == function.start_line: + if func.name == function.function_name and func.start_line == function.starting_line: target_func = func break # Use doc_start_line if available, otherwise fall back to start_line - effective_start = (target_func.doc_start_line if target_func else None) or function.start_line + effective_start = (target_func.doc_start_line if target_func else None) or function.starting_line # Get indentation from original function's first line - if function.start_line <= len(lines): - original_first_line = lines[function.start_line - 1] + if function.starting_line <= len(lines): + original_first_line = lines[function.starting_line - 1] original_indent = len(original_first_line) - len(original_first_line.lstrip()) else: original_indent = 0 @@ -1423,7 +1428,7 @@ def _replace_function_text_based( # Build result before = lines[: effective_start - 1] - after = lines[function.end_line :] + after = lines[function.ending_line :] result_lines = before + new_lines + after return "".join(result_lines) @@ -1574,7 +1579,7 @@ def parse_test_results(self, junit_xml_path: Path, stdout: str) -> list[TestResu # === Instrumentation === def instrument_for_behavior( - self, source: str, functions: Sequence[FunctionInfo], output_file: Path | None = None + self, source: str, functions: Sequence[FunctionToOptimize], output_file: Path | None = None ) -> str: """Add behavior instrumentation to capture inputs/outputs. @@ -1603,7 +1608,7 @@ def instrument_for_behavior( tracer = JavaScriptTracer(output_file) return tracer.instrument_source(source, functions[0].file_path, list(functions)) - def instrument_for_benchmarking(self, test_source: str, target_function: FunctionInfo) -> str: + def instrument_for_benchmarking(self, test_source: str, target_function: FunctionToOptimize) -> str: """Add timing instrumentation to test code. For JavaScript/Jest, we can use Jest's built-in timing or add custom timing. @@ -1964,7 +1969,7 @@ def instrument_existing_test( def instrument_source_for_line_profiler( # TODO: use the context to instrument helper files also self, - func_info: FunctionInfo, + func_info: FunctionToOptimize, line_profiler_output_file: Path, ) -> bool: from codeflash.languages.javascript.line_profiler import JavaScriptLineProfiler diff --git a/codeflash/languages/javascript/tracer.py b/codeflash/languages/javascript/tracer.py index 66b97b488..7632db50b 100644 --- a/codeflash/languages/javascript/tracer.py +++ b/codeflash/languages/javascript/tracer.py @@ -15,7 +15,7 @@ if TYPE_CHECKING: from pathlib import Path - from codeflash.languages.base import FunctionInfo + from codeflash.discovery.functions_to_optimize import FunctionToOptimize logger = logging.getLogger(__name__) @@ -40,7 +40,7 @@ def __init__(self, output_db: Path) -> None: self.output_db = output_db self.tracer_var = "__codeflash_tracer__" - def instrument_source(self, source: str, file_path: Path, functions: list[FunctionInfo]) -> str: + def instrument_source(self, source: str, file_path: Path, functions: list[FunctionToOptimize]) -> str: """Instrument JavaScript source code with function tracing. Wraps specified functions to capture their inputs and outputs. @@ -269,7 +269,7 @@ def _generate_tracer_save(self) -> str: }}); """ - def _instrument_function(self, func: FunctionInfo, lines: list[str], file_path: Path) -> list[str]: + def _instrument_function(self, func: FunctionToOptimize, lines: list[str], file_path: Path) -> list[str]: """Instrument a single function with tracing. Args: diff --git a/codeflash/languages/language_enum.py b/codeflash/languages/language_enum.py new file mode 100644 index 000000000..7ddded0fe --- /dev/null +++ b/codeflash/languages/language_enum.py @@ -0,0 +1,17 @@ +"""Language enum for multi-language support. + +This module is kept separate to avoid circular imports. +""" + +from enum import Enum + + +class Language(str, Enum): + """Supported programming languages.""" + + PYTHON = "python" + JAVASCRIPT = "javascript" + TYPESCRIPT = "typescript" + + def __str__(self) -> str: + return self.value diff --git a/codeflash/languages/python/support.py b/codeflash/languages/python/support.py index bf35c7777..8268ca8d3 100644 --- a/codeflash/languages/python/support.py +++ b/codeflash/languages/python/support.py @@ -6,13 +6,12 @@ from pathlib import Path from typing import TYPE_CHECKING, Any +from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.languages.base import ( CodeContext, FunctionFilterCriteria, - FunctionInfo, HelperFunction, Language, - ParentInfo, ReferenceInfo, TestInfo, TestResult, @@ -64,7 +63,7 @@ def comment_prefix(self) -> str: def discover_functions( self, file_path: Path, filter_criteria: FunctionFilterCriteria | None = None - ) -> list[FunctionInfo]: + ) -> list[FunctionToOptimize]: """Find all optimizable functions in a Python file. Uses libcst to parse the file and find functions with return statements. @@ -74,12 +73,12 @@ def discover_functions( filter_criteria: Optional criteria to filter functions. Returns: - List of FunctionInfo objects for discovered functions. + List of FunctionToOptimize objects for discovered functions. """ import libcst as cst - from codeflash.discovery.functions_to_optimize import FunctionToOptimize, FunctionVisitor + from codeflash.discovery.functions_to_optimize import FunctionVisitor criteria = filter_criteria or FunctionFilterCriteria() @@ -96,7 +95,7 @@ def discover_functions( function_visitor = FunctionVisitor(file_path=str(file_path)) wrapper.visit(function_visitor) - functions: list[FunctionInfo] = [] + functions: list[FunctionToOptimize] = [] for func in function_visitor.functions: if not isinstance(func, FunctionToOptimize): continue @@ -113,23 +112,20 @@ def discover_functions( if criteria.require_return and func.starting_line is None: continue - # Convert FunctionToOptimize to FunctionInfo - parents = tuple(ParentInfo(name=p.name, type=p.type) for p in func.parents) - - functions.append( - FunctionInfo( - name=func.function_name, - file_path=file_path, - start_line=func.starting_line or 1, - end_line=func.ending_line or 1, - start_col=func.starting_col, - end_col=func.ending_col, - parents=parents, - is_async=func.is_async, - is_method=len(func.parents) > 0, - language=Language.PYTHON, - ) + # Add is_method field based on parents + func_with_is_method = FunctionToOptimize( + function_name=func.function_name, + file_path=file_path, + parents=func.parents, + starting_line=func.starting_line, + ending_line=func.ending_line, + starting_col=func.starting_col, + ending_col=func.ending_col, + is_async=func.is_async, + is_method=len(func.parents) > 0 and any(p.type == "ClassDef" for p in func.parents), + language="python", ) + functions.append(func_with_is_method) return functions @@ -137,7 +133,9 @@ def discover_functions( logger.warning("Failed to discover functions in %s: %s", file_path, e) return [] - def discover_tests(self, test_root: Path, source_functions: Sequence[FunctionInfo]) -> dict[str, list[TestInfo]]: + def discover_tests( + self, test_root: Path, source_functions: Sequence[FunctionToOptimize] + ) -> dict[str, list[TestInfo]]: """Map source functions to their tests via static analysis. Args: @@ -172,7 +170,7 @@ def discover_tests(self, test_root: Path, source_functions: Sequence[FunctionInf # === Code Analysis === - def extract_code_context(self, function: FunctionInfo, project_root: Path, module_root: Path) -> CodeContext: + def extract_code_context(self, function: FunctionToOptimize, project_root: Path, module_root: Path) -> CodeContext: """Extract function code and its dependencies. Uses jedi and libcst for Python code analysis. @@ -194,8 +192,8 @@ def extract_code_context(self, function: FunctionInfo, project_root: Path, modul # Extract the function source lines = source.splitlines(keepends=True) - if function.start_line and function.end_line: - target_lines = lines[function.start_line - 1 : function.end_line] + if function.starting_line and function.ending_line: + target_lines = lines[function.starting_line - 1 : function.ending_line] target_code = "".join(target_lines) else: target_code = "" @@ -222,7 +220,7 @@ def extract_code_context(self, function: FunctionInfo, project_root: Path, modul language=Language.PYTHON, ) - def find_helper_functions(self, function: FunctionInfo, project_root: Path) -> list[HelperFunction]: + def find_helper_functions(self, function: FunctionToOptimize, project_root: Path) -> list[HelperFunction]: """Find helper functions called by the target function. Uses jedi for Python code analysis. @@ -296,11 +294,7 @@ def find_helper_functions(self, function: FunctionInfo, project_root: Path) -> l return helpers def find_references( - self, - function: FunctionInfo, - project_root: Path, - tests_root: Path | None = None, - max_files: int = 500, + self, function: FunctionToOptimize, project_root: Path, tests_root: Path | None = None, max_files: int = 500 ) -> list[ReferenceInfo]: """Find all references (call sites) to a function across the codebase. @@ -411,14 +405,14 @@ def find_references( # === Code Transformation === - def replace_function(self, source: str, function: FunctionInfo, new_source: str) -> str: + def replace_function(self, source: str, function: FunctionToOptimize, new_source: str) -> str: """Replace a function in source code with new implementation. Uses libcst for Python code transformation. Args: source: Original source code. - function: FunctionInfo identifying the function to replace. + function: FunctionToOptimize identifying the function to replace. new_source: New function source code. Returns: @@ -585,7 +579,7 @@ def parse_test_results(self, junit_xml_path: Path, stdout: str) -> list[TestResu # === Instrumentation === - def instrument_for_behavior(self, source: str, functions: Sequence[FunctionInfo]) -> str: + def instrument_for_behavior(self, source: str, functions: Sequence[FunctionToOptimize]) -> str: """Add behavior instrumentation to capture inputs/outputs. Args: @@ -600,7 +594,7 @@ def instrument_for_behavior(self, source: str, functions: Sequence[FunctionInfo] # This is a pass-through for now return source - def instrument_for_benchmarking(self, test_source: str, target_function: FunctionInfo) -> str: + def instrument_for_benchmarking(self, test_source: str, target_function: FunctionToOptimize) -> str: """Add timing instrumentation to test code. Args: @@ -841,7 +835,9 @@ def instrument_existing_test( mode=testing_mode, ) - def instrument_source_for_line_profiler(self, func_info: FunctionInfo, line_profiler_output_file: Path) -> bool: + def instrument_source_for_line_profiler( + self, func_info: FunctionToOptimize, line_profiler_output_file: Path + ) -> bool: """Instrument source code for line profiling. Args: diff --git a/codeflash/languages/registry.py b/codeflash/languages/registry.py index 778171144..426e5ada2 100644 --- a/codeflash/languages/registry.py +++ b/codeflash/languages/registry.py @@ -11,7 +11,7 @@ from pathlib import Path from typing import TYPE_CHECKING -from codeflash.languages.base import Language +from codeflash.languages.language_enum import Language if TYPE_CHECKING: from collections.abc import Iterable diff --git a/codeflash/languages/treesitter_utils.py b/codeflash/languages/treesitter_utils.py index 60ceea8e0..be45db401 100644 --- a/codeflash/languages/treesitter_utils.py +++ b/codeflash/languages/treesitter_utils.py @@ -464,6 +464,7 @@ def _walk_tree_for_imports( source_bytes: Source code bytes. imports: List to append found imports to. in_function: Whether we're currently inside a function/method body. + """ # Track when we enter function/method bodies # These node types contain function/method bodies where require() should not be treated as imports diff --git a/codeflash/models/models.py b/codeflash/models/models.py index ee6a92b79..1e8509c66 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -603,6 +603,9 @@ class FunctionParent: name: str type: str + def __str__(self) -> str: + return f"{self.type}:{self.name}" + class OriginalCodeBaseline(BaseModel): behavior_test_results: TestResults diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 0ed8fbb67..2f2586e02 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -77,7 +77,7 @@ from codeflash.discovery.functions_to_optimize import was_function_previously_optimized from codeflash.either import Failure, Success, is_successful from codeflash.languages import is_python -from codeflash.languages.base import FunctionInfo, Language +from codeflash.languages.base import Language from codeflash.languages.current import current_language_support, is_typescript from codeflash.languages.javascript.module_system import detect_module_system from codeflash.lsp.helpers import is_LSP_enabled, report_to_markdown_table, tree_to_markdown @@ -2172,10 +2172,10 @@ def process_review( else self.function_trace_id, "coverage_message": coverage_message, "replay_tests": replay_tests, - #"concolic_tests": concolic_tests, + # "concolic_tests": concolic_tests, "language": self.function_to_optimize.language, - #"original_line_profiler": original_code_baseline.line_profile_results.get("str_out", ""), - #"optimized_line_profiler": best_optimization.line_profiler_test_results.get("str_out", ""), + # "original_line_profiler": original_code_baseline.line_profile_results.get("str_out", ""), + # "optimized_line_profiler": best_optimization.line_profiler_test_results.get("str_out", ""), } raise_pr = not self.args.no_pr @@ -2842,18 +2842,8 @@ def line_profiler_step( # NOTE: currently this handles single file only, add support to multi file instrumentation (or should it be kept for the main file only) original_source = Path(self.function_to_optimize.file_path).read_text() # Instrument source code - func_info = FunctionInfo( - name=self.function_to_optimize.function_name, - file_path=self.function_to_optimize.file_path, - start_line=self.function_to_optimize.starting_line, - end_line=self.function_to_optimize.ending_line, - start_col=self.function_to_optimize.starting_col, - end_col=self.function_to_optimize.ending_col, - is_async=self.function_to_optimize.is_async, - language=self.language_support.language, - ) success = self.language_support.instrument_source_for_line_profiler( - func_info=func_info, line_profiler_output_file=line_profiler_output_path + func_info=self.function_to_optimize, line_profiler_output_file=line_profiler_output_path ) if not success: return {"timings": {}, "unit": 0, "str_out": ""} diff --git a/codeflash/verification/verifier.py b/codeflash/verification/verifier.py index c53f71cd5..c29637046 100644 --- a/codeflash/verification/verifier.py +++ b/codeflash/verification/verifier.py @@ -83,16 +83,12 @@ def generate_tests( # Instrument for behavior verification (writes to SQLite) instrumented_behavior_test_source = instrument_generated_js_test( - test_code=generated_test_source, - function_to_optimize=function_to_optimize, - mode=TestingMode.BEHAVIOR, + test_code=generated_test_source, function_to_optimize=function_to_optimize, mode=TestingMode.BEHAVIOR ) # Instrument for performance measurement (prints to stdout) instrumented_perf_test_source = instrument_generated_js_test( - test_code=generated_test_source, - function_to_optimize=function_to_optimize, - mode=TestingMode.PERFORMANCE, + test_code=generated_test_source, function_to_optimize=function_to_optimize, mode=TestingMode.PERFORMANCE ) logger.debug(f"Instrumented JS/TS tests locally for {function_to_optimize.function_name}") diff --git a/tests/test_languages/test_base.py b/tests/test_languages/test_base.py index dd8f86324..321e71388 100644 --- a/tests/test_languages/test_base.py +++ b/tests/test_languages/test_base.py @@ -87,138 +87,138 @@ def test_parent_info_hash(self): class TestFunctionInfo: - """Tests for the FunctionInfo dataclass.""" + """Tests for the FunctionInfo dataclass (alias for FunctionToOptimize).""" def test_function_info_creation_minimal(self): """Test creating FunctionInfo with minimal args.""" - func = FunctionInfo(name="add", file_path=Path("/test/example.py"), start_line=1, end_line=3) - assert func.name == "add" + func = FunctionInfo(function_name="add", file_path=Path("/test/example.py"), starting_line=1, ending_line=3) + assert func.function_name == "add" assert func.file_path == Path("/test/example.py") - assert func.start_line == 1 - assert func.end_line == 3 - assert func.parents == () + assert func.starting_line == 1 + assert func.ending_line == 3 + assert func.parents == [] assert func.is_async is False assert func.is_method is False - assert func.language == Language.PYTHON + assert func.language == "python" def test_function_info_creation_full(self): """Test creating FunctionInfo with all args.""" - parents = (ParentInfo(name="Calculator", type="ClassDef"),) + parents = [ParentInfo(name="Calculator", type="ClassDef")] func = FunctionInfo( - name="add", + function_name="add", file_path=Path("/test/example.py"), - start_line=10, - end_line=15, + starting_line=10, + ending_line=15, parents=parents, is_async=True, is_method=True, - language=Language.PYTHON, - start_col=4, - end_col=20, + language="python", + starting_col=4, + ending_col=20, ) - assert func.name == "add" + assert func.function_name == "add" assert func.parents == parents assert func.is_async is True assert func.is_method is True - assert func.start_col == 4 - assert func.end_col == 20 + assert func.starting_col == 4 + assert func.ending_col == 20 def test_function_info_frozen(self): """Test that FunctionInfo is immutable.""" - func = FunctionInfo(name="add", file_path=Path("/test/example.py"), start_line=1, end_line=3) + func = FunctionInfo(function_name="add", file_path=Path("/test/example.py"), starting_line=1, ending_line=3) with pytest.raises(AttributeError): - func.name = "new_name" + func.function_name = "new_name" def test_qualified_name_no_parents(self): """Test qualified_name without parents.""" - func = FunctionInfo(name="add", file_path=Path("/test/example.py"), start_line=1, end_line=3) + func = FunctionInfo(function_name="add", file_path=Path("/test/example.py"), starting_line=1, ending_line=3) assert func.qualified_name == "add" def test_qualified_name_with_class(self): """Test qualified_name with class parent.""" func = FunctionInfo( - name="add", + function_name="add", file_path=Path("/test/example.py"), - start_line=1, - end_line=3, - parents=(ParentInfo(name="Calculator", type="ClassDef"),), + starting_line=1, + ending_line=3, + parents=[ParentInfo(name="Calculator", type="ClassDef")], ) assert func.qualified_name == "Calculator.add" def test_qualified_name_nested(self): """Test qualified_name with nested parents.""" func = FunctionInfo( - name="inner", + function_name="inner", file_path=Path("/test/example.py"), - start_line=1, - end_line=3, - parents=(ParentInfo(name="Outer", type="ClassDef"), ParentInfo(name="Inner", type="ClassDef")), + starting_line=1, + ending_line=3, + parents=[ParentInfo(name="Outer", type="ClassDef"), ParentInfo(name="Inner", type="ClassDef")], ) assert func.qualified_name == "Outer.Inner.inner" def test_class_name_with_class(self): """Test class_name property with class parent.""" func = FunctionInfo( - name="add", + function_name="add", file_path=Path("/test/example.py"), - start_line=1, - end_line=3, - parents=(ParentInfo(name="Calculator", type="ClassDef"),), + starting_line=1, + ending_line=3, + parents=[ParentInfo(name="Calculator", type="ClassDef")], ) assert func.class_name == "Calculator" def test_class_name_without_class(self): """Test class_name property without class parent.""" - func = FunctionInfo(name="add", file_path=Path("/test/example.py"), start_line=1, end_line=3) + func = FunctionInfo(function_name="add", file_path=Path("/test/example.py"), starting_line=1, ending_line=3) assert func.class_name is None def test_class_name_nested_function(self): """Test class_name for function nested in another function.""" func = FunctionInfo( - name="inner", + function_name="inner", file_path=Path("/test/example.py"), - start_line=1, - end_line=3, - parents=(ParentInfo(name="outer", type="FunctionDef"),), + starting_line=1, + ending_line=3, + parents=[ParentInfo(name="outer", type="FunctionDef")], ) assert func.class_name is None def test_class_name_method_in_nested_class(self): """Test class_name for method in nested class.""" func = FunctionInfo( - name="method", + function_name="method", file_path=Path("/test/example.py"), - start_line=1, - end_line=3, - parents=(ParentInfo(name="Outer", type="ClassDef"), ParentInfo(name="Inner", type="ClassDef")), + starting_line=1, + ending_line=3, + parents=[ParentInfo(name="Outer", type="ClassDef"), ParentInfo(name="Inner", type="ClassDef")], ) # Should return the immediate parent class assert func.class_name == "Inner" def test_top_level_parent_name_no_parents(self): """Test top_level_parent_name without parents.""" - func = FunctionInfo(name="add", file_path=Path("/test/example.py"), start_line=1, end_line=3) + func = FunctionInfo(function_name="add", file_path=Path("/test/example.py"), starting_line=1, ending_line=3) assert func.top_level_parent_name == "add" def test_top_level_parent_name_with_parents(self): """Test top_level_parent_name with parents.""" func = FunctionInfo( - name="method", + function_name="method", file_path=Path("/test/example.py"), - start_line=1, - end_line=3, - parents=(ParentInfo(name="Outer", type="ClassDef"), ParentInfo(name="Inner", type="ClassDef")), + starting_line=1, + ending_line=3, + parents=[ParentInfo(name="Outer", type="ClassDef"), ParentInfo(name="Inner", type="ClassDef")], ) assert func.top_level_parent_name == "Outer" def test_function_info_str(self): """Test string representation.""" func = FunctionInfo( - name="add", + function_name="add", file_path=Path("/test/example.py"), - start_line=1, - end_line=3, - parents=(ParentInfo(name="Calculator", type="ClassDef"),), + starting_line=1, + ending_line=3, + parents=[ParentInfo(name="Calculator", type="ClassDef")], ) s = str(func) assert "Calculator.add" in s diff --git a/tests/test_languages/test_javascript_e2e.py b/tests/test_languages/test_javascript_e2e.py index a538e9249..2fe25c18a 100644 --- a/tests/test_languages/test_javascript_e2e.py +++ b/tests/test_languages/test_javascript_e2e.py @@ -172,7 +172,7 @@ def test_replace_function_in_javascript_file(self): js_support = get_language_support(Language.JAVASCRIPT) func_info = FunctionInfo( - name="add", file_path=Path("/tmp/test.js"), start_line=2, end_line=4, language=Language.JAVASCRIPT + function_name="add", file_path=Path("/tmp/test.js"), starting_line=2, ending_line=4, language="javascript" ) result = js_support.replace_function(original_source, func_info, new_function) @@ -216,7 +216,7 @@ def test_discover_jest_tests(self, js_project_dir): fib_file = js_project_dir / "fibonacci.js" func_info = FunctionInfo( - name="fibonacci", file_path=fib_file, start_line=11, end_line=16, language=Language.JAVASCRIPT + function_name="fibonacci", file_path=fib_file, starting_line=11, ending_line=16, language="javascript" ) tests = js_support.discover_tests(test_root, [func_info]) diff --git a/tests/test_languages/test_javascript_instrumentation.py b/tests/test_languages/test_javascript_instrumentation.py index 423504f66..9896d1d69 100644 --- a/tests/test_languages/test_javascript_instrumentation.py +++ b/tests/test_languages/test_javascript_instrumentation.py @@ -62,7 +62,7 @@ def test_line_profiler_instruments_simple_function(self): file_path = Path(f.name) func_info = FunctionInfo( - name="add", file_path=file_path, start_line=2, end_line=5, language=Language.JAVASCRIPT + function_name="add", file_path=file_path, starting_line=2, ending_line=5, language="javascript" ) output_file = Path("/tmp/test_profile.json") @@ -123,7 +123,7 @@ def test_tracer_instruments_simple_function(self): file_path = Path(f.name) func_info = FunctionInfo( - name="multiply", file_path=file_path, start_line=2, end_line=4, language=Language.JAVASCRIPT + function_name="multiply", file_path=file_path, starting_line=2, ending_line=4, language="javascript" ) output_db = Path("/tmp/test_traces.db") @@ -167,7 +167,7 @@ def test_javascript_support_instrument_for_behavior(self): file_path = Path(f.name) func_info = FunctionInfo( - name="greet", file_path=file_path, start_line=2, end_line=4, language=Language.JAVASCRIPT + function_name="greet", file_path=file_path, starting_line=2, ending_line=4, language="javascript" ) output_file = file_path.parent / ".codeflash" / "traces.db" @@ -198,7 +198,7 @@ def test_javascript_support_instrument_for_line_profiling(self): file_path = Path(f.name) func_info = FunctionInfo( - name="square", file_path=file_path, start_line=2, end_line=5, language=Language.JAVASCRIPT + function_name="square", file_path=file_path, starting_line=2, ending_line=5, language="javascript" ) output_file = file_path.parent / ".codeflash" / "line_profile.json" diff --git a/tests/test_languages/test_javascript_support.py b/tests/test_languages/test_javascript_support.py index 5f7f530c3..887e07b98 100644 --- a/tests/test_languages/test_javascript_support.py +++ b/tests/test_languages/test_javascript_support.py @@ -55,7 +55,7 @@ def test_discover_simple_function(self, js_support): functions = js_support.discover_functions(Path(f.name)) assert len(functions) == 1 - assert functions[0].name == "add" + assert functions[0].function_name == "add" assert functions[0].language == Language.JAVASCRIPT def test_discover_multiple_functions(self, js_support): @@ -79,7 +79,7 @@ def test_discover_multiple_functions(self, js_support): functions = js_support.discover_functions(Path(f.name)) assert len(functions) == 3 - names = {func.name for func in functions} + names = {func.function_name for func in functions} assert names == {"add", "subtract", "multiply"} def test_discover_arrow_function(self, js_support): @@ -97,7 +97,7 @@ def test_discover_arrow_function(self, js_support): functions = js_support.discover_functions(Path(f.name)) assert len(functions) == 2 - names = {func.name for func in functions} + names = {func.function_name for func in functions} assert names == {"add", "multiply"} def test_discover_function_without_return_excluded(self, js_support): @@ -118,7 +118,7 @@ def test_discover_function_without_return_excluded(self, js_support): # Only the function with return should be discovered assert len(functions) == 1 - assert functions[0].name == "withReturn" + assert functions[0].function_name == "withReturn" def test_discover_class_methods(self, js_support): """Test discovering class methods.""" @@ -161,8 +161,8 @@ def test_discover_async_functions(self, js_support): assert len(functions) == 2 - async_func = next(f for f in functions if f.name == "fetchData") - sync_func = next(f for f in functions if f.name == "syncFunction") + async_func = next(f for f in functions if f.function_name == "fetchData") + sync_func = next(f for f in functions if f.function_name == "syncFunction") assert async_func.is_async is True assert sync_func.is_async is False @@ -185,7 +185,7 @@ def test_discover_with_filter_exclude_async(self, js_support): functions = js_support.discover_functions(Path(f.name), criteria) assert len(functions) == 1 - assert functions[0].name == "syncFunc" + assert functions[0].function_name == "syncFunc" def test_discover_with_filter_exclude_methods(self, js_support): """Test filtering out class methods.""" @@ -207,7 +207,7 @@ class MyClass { functions = js_support.discover_functions(Path(f.name), criteria) assert len(functions) == 1 - assert functions[0].name == "standalone" + assert functions[0].function_name == "standalone" def test_discover_line_numbers(self, js_support): """Test that line numbers are correctly captured.""" @@ -226,13 +226,13 @@ def test_discover_line_numbers(self, js_support): functions = js_support.discover_functions(Path(f.name)) - func1 = next(f for f in functions if f.name == "func1") - func2 = next(f for f in functions if f.name == "func2") + func1 = next(f for f in functions if f.function_name == "func1") + func2 = next(f for f in functions if f.function_name == "func2") - assert func1.start_line == 1 - assert func1.end_line == 3 - assert func2.start_line == 5 - assert func2.end_line == 9 + assert func1.starting_line == 1 + assert func1.ending_line == 3 + assert func2.starting_line == 5 + assert func2.ending_line == 9 def test_discover_generator_function(self, js_support): """Test discovering generator functions.""" @@ -249,7 +249,7 @@ def test_discover_generator_function(self, js_support): functions = js_support.discover_functions(Path(f.name)) assert len(functions) == 1 - assert functions[0].name == "numberGenerator" + assert functions[0].function_name == "numberGenerator" def test_discover_invalid_file_returns_empty(self, js_support): """Test that invalid JavaScript file returns empty list.""" @@ -280,7 +280,7 @@ def test_discover_function_expression(self, js_support): functions = js_support.discover_functions(Path(f.name)) assert len(functions) == 1 - assert functions[0].name == "add" + assert functions[0].function_name == "add" def test_discover_immediately_invoked_function_excluded(self, js_support): """Test that IIFEs without names are excluded when require_name is True.""" @@ -300,7 +300,7 @@ def test_discover_immediately_invoked_function_excluded(self, js_support): # Only the named function should be discovered assert len(functions) == 1 - assert functions[0].name == "named" + assert functions[0].function_name == "named" class TestReplaceFunction: @@ -316,7 +316,7 @@ def test_replace_simple_function(self, js_support): return a * b; } """ - func = FunctionInfo(name="add", file_path=Path("/test.js"), start_line=1, end_line=3) + func = FunctionInfo(function_name="add", file_path=Path("/test.js"), starting_line=1, ending_line=3) new_code = """function add(a, b) { // Optimized return (a + b) | 0; @@ -343,7 +343,7 @@ def test_replace_preserves_surrounding_code(self, js_support): // Footer """ - func = FunctionInfo(name="target", file_path=Path("/test.js"), start_line=4, end_line=6) + func = FunctionInfo(function_name="target", file_path=Path("/test.js"), starting_line=4, ending_line=6) new_code = """function target() { return 42; } @@ -365,11 +365,11 @@ def test_replace_with_indentation_adjustment(self, js_support): } """ func = FunctionInfo( - name="add", + function_name="add", file_path=Path("/test.js"), - start_line=2, - end_line=4, - parents=(ParentInfo(name="Calculator", type="ClassDef"),), + starting_line=2, + ending_line=4, + parents=[ParentInfo(name="Calculator", type="ClassDef")], ) # New code has no indentation new_code = """add(a, b) { @@ -391,7 +391,7 @@ def test_replace_arrow_function(self, js_support): const multiply = (x, y) => x * y; """ - func = FunctionInfo(name="add", file_path=Path("/test.js"), start_line=1, end_line=3) + func = FunctionInfo(function_name="add", file_path=Path("/test.js"), starting_line=1, ending_line=3) new_code = """const add = (a, b) => { return (a + b) | 0; }; @@ -483,7 +483,7 @@ def test_extract_simple_function(self, js_support): f.flush() file_path = Path(f.name) - func = FunctionInfo(name="add", file_path=file_path, start_line=1, end_line=3) + func = FunctionInfo(function_name="add", file_path=file_path, starting_line=1, ending_line=3) context = js_support.extract_code_context(func, file_path.parent, file_path.parent) @@ -508,7 +508,7 @@ def test_extract_with_helper(self, js_support): # First discover functions to get accurate line numbers functions = js_support.discover_functions(file_path) - main_func = next(f for f in functions if f.name == "main") + main_func = next(f for f in functions if f.function_name == "main") context = js_support.extract_code_context(main_func, file_path.parent, file_path.parent) @@ -538,7 +538,7 @@ def test_discover_and_replace_workflow(self, js_support): functions = js_support.discover_functions(file_path) assert len(functions) == 1 func = functions[0] - assert func.name == "fibonacci" + assert func.function_name == "fibonacci" # Replace optimized_code = """function fibonacci(n) { @@ -626,7 +626,7 @@ def test_jsx_file(self, js_support): functions = js_support.discover_functions(file_path) # Should find both components - names = {f.name for f in functions} + names = {f.function_name for f in functions} assert "Button" in names assert "Card" in names @@ -688,7 +688,7 @@ def test_extract_class_method_wraps_in_class(self, js_support): # Discover the method functions = js_support.discover_functions(file_path) - add_method = next(f for f in functions if f.name == "add") + add_method = next(f for f in functions if f.function_name == "add") # Extract code context context = js_support.extract_code_context(add_method, file_path.parent, file_path.parent) @@ -725,7 +725,7 @@ class Calculator { file_path = Path(f.name) functions = js_support.discover_functions(file_path) - add_method = next(f for f in functions if f.name == "add") + add_method = next(f for f in functions if f.function_name == "add") context = js_support.extract_code_context(add_method, file_path.parent, file_path.parent) @@ -764,7 +764,7 @@ def test_extract_class_method_syntax_valid(self, js_support): file_path = Path(f.name) functions = js_support.discover_functions(file_path) - fib_method = next(f for f in functions if f.name == "fibonacci") + fib_method = next(f for f in functions if f.function_name == "fibonacci") context = js_support.extract_code_context(fib_method, file_path.parent, file_path.parent) @@ -802,7 +802,7 @@ def test_extract_nested_class_method(self, js_support): file_path = Path(f.name) functions = js_support.discover_functions(file_path) - add_method = next((f for f in functions if f.name == "add"), None) + add_method = next((f for f in functions if f.function_name == "add"), None) if add_method: context = js_support.extract_code_context(add_method, file_path.parent, file_path.parent) @@ -831,7 +831,7 @@ def test_extract_async_class_method(self, js_support): file_path = Path(f.name) functions = js_support.discover_functions(file_path) - fetch_method = next(f for f in functions if f.name == "fetchData") + fetch_method = next(f for f in functions if f.function_name == "fetchData") context = js_support.extract_code_context(fetch_method, file_path.parent, file_path.parent) @@ -863,7 +863,7 @@ def test_extract_static_class_method(self, js_support): file_path = Path(f.name) functions = js_support.discover_functions(file_path) - add_method = next((f for f in functions if f.name == "add"), None) + add_method = next((f for f in functions if f.function_name == "add"), None) if add_method: context = js_support.extract_code_context(add_method, file_path.parent, file_path.parent) @@ -891,7 +891,7 @@ def test_extract_class_method_without_class_jsdoc(self, js_support): file_path = Path(f.name) functions = js_support.discover_functions(file_path) - method = next(f for f in functions if f.name == "simpleMethod") + method = next(f for f in functions if f.function_name == "simpleMethod") context = js_support.extract_code_context(method, file_path.parent, file_path.parent) @@ -922,11 +922,11 @@ def test_replace_class_method_preserves_class_structure(self, js_support): } """ func = FunctionInfo( - name="add", + function_name="add", file_path=Path("/test.js"), - start_line=2, - end_line=4, - parents=(ParentInfo(name="Calculator", type="ClassDef"),), + starting_line=2, + ending_line=4, + parents=[ParentInfo(name="Calculator", type="ClassDef")], is_method=True, ) new_code = """ add(a, b) { @@ -963,12 +963,12 @@ def test_replace_class_method_with_jsdoc(self, js_support): } """ func = FunctionInfo( - name="add", + function_name="add", file_path=Path("/test.js"), - start_line=5, # Method starts here - end_line=7, + starting_line=5, # Method starts here + ending_line=7, doc_start_line=2, # JSDoc starts here - parents=(ParentInfo(name="Calculator", type="ClassDef"),), + parents=[ParentInfo(name="Calculator", type="ClassDef")], is_method=True, ) new_code = """ /** @@ -1000,11 +1000,11 @@ def test_replace_multiple_class_methods_sequentially(self, js_support): """ # Replace add first add_func = FunctionInfo( - name="add", + function_name="add", file_path=Path("/test.js"), - start_line=2, - end_line=4, - parents=(ParentInfo(name="Math", type="ClassDef"),), + starting_line=2, + ending_line=4, + parents=[ParentInfo(name="Math", type="ClassDef")], is_method=True, ) source = js_support.replace_function( @@ -1032,11 +1032,11 @@ def test_replace_class_method_indentation_adjustment(self, js_support): } """ func = FunctionInfo( - name="innerMethod", + function_name="innerMethod", file_path=Path("/test.js"), - start_line=2, - end_line=4, - parents=(ParentInfo(name="Indented", type="ClassDef"),), + starting_line=2, + ending_line=4, + parents=[ParentInfo(name="Indented", type="ClassDef")], is_method=True, ) # New code with no indentation @@ -1077,7 +1077,7 @@ def test_class_with_constructor(self, js_support): functions = js_support.discover_functions(file_path) # Should find constructor and increment - names = {f.name for f in functions} + names = {f.function_name for f in functions} assert "constructor" in names or "increment" in names def test_class_with_getters_setters(self, js_support): @@ -1107,7 +1107,7 @@ def test_class_with_getters_setters(self, js_support): functions = js_support.discover_functions(file_path) # Should find at least greet - names = {f.name for f in functions} + names = {f.function_name for f in functions} assert "greet" in names def test_class_extending_another(self, js_support): @@ -1135,7 +1135,7 @@ class Dog extends Animal { functions = js_support.discover_functions(file_path) # Find Dog's fetch method - fetch_method = next((f for f in functions if f.name == "fetch" and f.class_name == "Dog"), None) + fetch_method = next((f for f in functions if f.function_name == "fetch" and f.class_name == "Dog"), None) if fetch_method: context = js_support.extract_code_context(fetch_method, file_path.parent, file_path.parent) @@ -1169,7 +1169,7 @@ def test_class_with_private_method(self, js_support): functions = js_support.discover_functions(file_path) # Should at least find publicMethod - names = {f.name for f in functions} + names = {f.function_name for f in functions} assert "publicMethod" in names def test_commonjs_class_export(self, js_support): @@ -1187,7 +1187,7 @@ def test_commonjs_class_export(self, js_support): file_path = Path(f.name) functions = js_support.discover_functions(file_path) - add_method = next(f for f in functions if f.name == "add") + add_method = next(f for f in functions if f.function_name == "add") context = js_support.extract_code_context(add_method, file_path.parent, file_path.parent) @@ -1209,7 +1209,7 @@ def test_es_module_class_export(self, js_support): functions = js_support.discover_functions(file_path) # Find the add method - add_method = next((f for f in functions if f.name == "add"), None) + add_method = next((f for f in functions if f.function_name == "add"), None) if add_method: context = js_support.extract_code_context(add_method, file_path.parent, file_path.parent) @@ -1260,7 +1260,7 @@ class Counter { file_path = Path(f.name) functions = js_support.discover_functions(file_path) - increment_func = next(fn for fn in functions if fn.name == "increment") + increment_func = next(fn for fn in functions if fn.function_name == "increment") # Step 1: Extract code context (includes constructor for AI context) context = js_support.extract_code_context(increment_func, file_path.parent, file_path.parent) @@ -1359,7 +1359,7 @@ class User { file_path = Path(f.name) functions = ts_support.discover_functions(file_path) - get_name_func = next(fn for fn in functions if fn.name == "getName") + get_name_func = next(fn for fn in functions if fn.function_name == "getName") # Step 1: Extract code context (includes fields and constructor) context = ts_support.extract_code_context(get_name_func, file_path.parent, file_path.parent) @@ -1461,7 +1461,7 @@ class Calculator { file_path = Path(f.name) functions = js_support.discover_functions(file_path) - add_func = next(fn for fn in functions if fn.name == "add") + add_func = next(fn for fn in functions if fn.function_name == "add") # Extract context for add context = js_support.extract_code_context(add_func, file_path.parent, file_path.parent) @@ -1547,7 +1547,7 @@ class MathUtils { file_path = Path(f.name) functions = js_support.discover_functions(file_path) - add_func = next(fn for fn in functions if fn.name == "add") + add_func = next(fn for fn in functions if fn.function_name == "add") # Extract context context = js_support.extract_code_context(add_func, file_path.parent, file_path.parent) diff --git a/tests/test_languages/test_language_parity.py b/tests/test_languages/test_language_parity.py index 639f4f1c0..ae57eb426 100644 --- a/tests/test_languages/test_language_parity.py +++ b/tests/test_languages/test_language_parity.py @@ -353,8 +353,8 @@ def test_simple_function_discovery(self, python_support, js_support): assert len(js_funcs) == 1, f"JavaScript found {len(js_funcs)}, expected 1" # Both should find 'add' - assert py_funcs[0].name == "add" - assert js_funcs[0].name == "add" + assert py_funcs[0].function_name == "add" + assert js_funcs[0].function_name == "add" # Both should have correct language assert py_funcs[0].language == Language.PYTHON @@ -373,8 +373,8 @@ def test_multiple_functions_discovery(self, python_support, js_support): assert len(js_funcs) == 3, f"JavaScript found {len(js_funcs)}, expected 3" # Both should find the same function names - py_names = {f.name for f in py_funcs} - js_names = {f.name for f in js_funcs} + py_names = {f.function_name for f in py_funcs} + js_names = {f.function_name for f in js_funcs} assert py_names == {"add", "subtract", "multiply"} assert js_names == {"add", "subtract", "multiply"} @@ -392,8 +392,8 @@ def test_functions_without_return_excluded(self, python_support, js_support): assert len(js_funcs) == 1, f"JavaScript found {len(js_funcs)}, expected 1" # The function with return should be found - assert py_funcs[0].name == "with_return" - assert js_funcs[0].name == "withReturn" + assert py_funcs[0].function_name == "with_return" + assert js_funcs[0].function_name == "withReturn" def test_class_methods_discovery(self, python_support, js_support): """Both should discover class methods with proper metadata.""" @@ -409,12 +409,12 @@ def test_class_methods_discovery(self, python_support, js_support): # All should be marked as methods for func in py_funcs: - assert func.is_method is True, f"Python {func.name} should be a method" - assert func.class_name == "Calculator", f"Python {func.name} should belong to Calculator" + assert func.is_method is True, f"Python {func.function_name} should be a method" + assert func.class_name == "Calculator", f"Python {func.function_name} should belong to Calculator" for func in js_funcs: - assert func.is_method is True, f"JavaScript {func.name} should be a method" - assert func.class_name == "Calculator", f"JavaScript {func.name} should belong to Calculator" + assert func.is_method is True, f"JavaScript {func.function_name} should be a method" + assert func.class_name == "Calculator", f"JavaScript {func.function_name} should belong to Calculator" def test_async_functions_discovery(self, python_support, js_support): """Both should correctly identify async functions.""" @@ -429,10 +429,10 @@ def test_async_functions_discovery(self, python_support, js_support): assert len(js_funcs) == 2, f"JavaScript found {len(js_funcs)}, expected 2" # Check async flags - py_async = next(f for f in py_funcs if "fetch" in f.name.lower()) - py_sync = next(f for f in py_funcs if "sync" in f.name.lower()) - js_async = next(f for f in js_funcs if "fetch" in f.name.lower()) - js_sync = next(f for f in js_funcs if "sync" in f.name.lower()) + py_async = next(f for f in py_funcs if "fetch" in f.function_name.lower()) + py_sync = next(f for f in py_funcs if "sync" in f.function_name.lower()) + js_async = next(f for f in js_funcs if "fetch" in f.function_name.lower()) + js_sync = next(f for f in js_funcs if "sync" in f.function_name.lower()) assert py_async.is_async is True, "Python async function should have is_async=True" assert py_sync.is_async is False, "Python sync function should have is_async=False" @@ -452,15 +452,15 @@ def test_nested_functions_discovery(self, python_support, js_support): assert len(js_funcs) == 2, f"JavaScript found {len(js_funcs)}, expected 2" # Check names - py_names = {f.name for f in py_funcs} - js_names = {f.name for f in js_funcs} + py_names = {f.function_name for f in py_funcs} + js_names = {f.function_name for f in js_funcs} assert py_names == {"outer", "inner"}, f"Python found {py_names}" assert js_names == {"outer", "inner"}, f"JavaScript found {js_names}" # Check parent info for inner function - py_inner = next(f for f in py_funcs if f.name == "inner") - js_inner = next(f for f in js_funcs if f.name == "inner") + py_inner = next(f for f in py_funcs if f.function_name == "inner") + js_inner = next(f for f in js_funcs if f.function_name == "inner") assert len(py_inner.parents) >= 1, "Python inner should have parent info" assert py_inner.parents[0].name == "outer", "Python inner's parent should be outer" @@ -482,8 +482,8 @@ def test_static_methods_discovery(self, python_support, js_support): assert len(js_funcs) == 1, f"JavaScript found {len(js_funcs)}, expected 1" # Both should find 'helper' belonging to 'Utils' - assert py_funcs[0].name == "helper" - assert js_funcs[0].name == "helper" + assert py_funcs[0].function_name == "helper" + assert js_funcs[0].function_name == "helper" assert py_funcs[0].class_name == "Utils" assert js_funcs[0].class_name == "Utils" @@ -532,8 +532,8 @@ def test_filter_exclude_async(self, python_support, js_support): assert len(js_funcs) == 1, f"JavaScript found {len(js_funcs)}, expected 1" # Should be the sync function - assert "sync" in py_funcs[0].name.lower() - assert "sync" in js_funcs[0].name.lower() + assert "sync" in py_funcs[0].function_name.lower() + assert "sync" in js_funcs[0].function_name.lower() def test_filter_exclude_methods(self, python_support, js_support): """Both should support filtering out class methods.""" @@ -550,8 +550,8 @@ def test_filter_exclude_methods(self, python_support, js_support): assert len(js_funcs) == 1, f"JavaScript found {len(js_funcs)}, expected 1" # Should be the standalone function - assert py_funcs[0].name == "standalone" - assert js_funcs[0].name == "standalone" + assert py_funcs[0].function_name == "standalone" + assert js_funcs[0].function_name == "standalone" def test_nonexistent_file_returns_empty(self, python_support, js_support): """Both should return empty list for nonexistent files.""" @@ -570,14 +570,14 @@ def test_line_numbers_captured(self, python_support, js_support): js_funcs = js_support.discover_functions(js_file) # Both should have start_line and end_line - assert py_funcs[0].start_line is not None - assert py_funcs[0].end_line is not None - assert js_funcs[0].start_line is not None - assert js_funcs[0].end_line is not None + assert py_funcs[0].starting_line is not None + assert py_funcs[0].ending_line is not None + assert js_funcs[0].starting_line is not None + assert js_funcs[0].ending_line is not None # Start should be before or equal to end - assert py_funcs[0].start_line <= py_funcs[0].end_line - assert js_funcs[0].start_line <= js_funcs[0].end_line + assert py_funcs[0].starting_line <= py_funcs[0].ending_line + assert js_funcs[0].starting_line <= js_funcs[0].ending_line # ============================================================================ @@ -604,8 +604,8 @@ def multiply(a, b): return a * b; } """ - py_func = FunctionInfo(name="add", file_path=Path("/test.py"), start_line=1, end_line=2) - js_func = FunctionInfo(name="add", file_path=Path("/test.js"), start_line=1, end_line=3) + py_func = FunctionInfo(function_name="add", file_path=Path("/test.py"), starting_line=1, ending_line=2) + js_func = FunctionInfo(function_name="add", file_path=Path("/test.js"), starting_line=1, ending_line=3) py_new = """def add(a, b): return (a + b) | 0 @@ -651,8 +651,8 @@ def other(): // Footer """ - py_func = FunctionInfo(name="target", file_path=Path("/test.py"), start_line=4, end_line=5) - js_func = FunctionInfo(name="target", file_path=Path("/test.js"), start_line=4, end_line=6) + py_func = FunctionInfo(function_name="target", file_path=Path("/test.py"), starting_line=4, ending_line=5) + js_func = FunctionInfo(function_name="target", file_path=Path("/test.js"), starting_line=4, ending_line=6) py_new = """def target(): return 42 @@ -693,18 +693,18 @@ def add(self, a, b): } """ py_func = FunctionInfo( - name="add", + function_name="add", file_path=Path("/test.py"), - start_line=2, - end_line=3, - parents=(ParentInfo(name="Calculator", type="ClassDef"),), + starting_line=2, + ending_line=3, + parents=[ParentInfo(name="Calculator", type="ClassDef")], ) js_func = FunctionInfo( - name="add", + function_name="add", file_path=Path("/test.js"), - start_line=2, - end_line=4, - parents=(ParentInfo(name="Calculator", type="ClassDef"),), + starting_line=2, + ending_line=4, + parents=[ParentInfo(name="Calculator", type="ClassDef")], ) # New code without indentation @@ -872,8 +872,8 @@ def test_simple_function_context(self, python_support, js_support): ".js", ) - py_func = FunctionInfo(name="add", file_path=py_file, start_line=1, end_line=2) - js_func = FunctionInfo(name="add", file_path=js_file, start_line=1, end_line=3) + py_func = FunctionInfo(function_name="add", file_path=py_file, starting_line=1, ending_line=2) + js_func = FunctionInfo(function_name="add", file_path=js_file, starting_line=1, ending_line=3) py_context = python_support.extract_code_context(py_func, py_file.parent, py_file.parent) js_context = js_support.extract_code_context(js_func, js_file.parent, js_file.parent) @@ -922,8 +922,8 @@ def test_discover_and_replace_workflow(self, python_support, js_support): assert len(py_funcs) == 1 assert len(js_funcs) == 1 - assert py_funcs[0].name == "fibonacci" - assert js_funcs[0].name == "fibonacci" + assert py_funcs[0].function_name == "fibonacci" + assert js_funcs[0].function_name == "fibonacci" # Replace py_optimized = """def fibonacci(n): @@ -974,20 +974,20 @@ def test_function_info_fields_populated(self, python_support, js_support): for py_func in py_funcs: # Check all expected fields are populated - assert py_func.name is not None, "Python: name should be populated" + assert py_func.function_name is not None, "Python: name should be populated" assert py_func.file_path is not None, "Python: file_path should be populated" - assert py_func.start_line is not None, "Python: start_line should be populated" - assert py_func.end_line is not None, "Python: end_line should be populated" + assert py_func.starting_line is not None, "Python: start_line should be populated" + assert py_func.ending_line is not None, "Python: end_line should be populated" assert py_func.language is not None, "Python: language should be populated" # is_method and class_name should be set for class methods assert py_func.is_method is not None, "Python: is_method should be populated" for js_func in js_funcs: # JavaScript should populate the same fields - assert js_func.name is not None, "JavaScript: name should be populated" + assert js_func.function_name is not None, "JavaScript: name should be populated" assert js_func.file_path is not None, "JavaScript: file_path should be populated" - assert js_func.start_line is not None, "JavaScript: start_line should be populated" - assert js_func.end_line is not None, "JavaScript: end_line should be populated" + assert js_func.starting_line is not None, "JavaScript: start_line should be populated" + assert js_func.ending_line is not None, "JavaScript: end_line should be populated" assert js_func.language is not None, "JavaScript: language should be populated" assert js_func.is_method is not None, "JavaScript: is_method should be populated" @@ -1006,7 +1006,7 @@ def test_arrow_functions_unique_to_js(self, js_support): funcs = js_support.discover_functions(js_file) # Should find all arrow functions - names = {f.name for f in funcs} + names = {f.function_name for f in funcs} assert "add" in names, "Should find arrow function 'add'" assert "multiply" in names, "Should find concise arrow function 'multiply'" # identity might or might not be found depending on implicit return handling @@ -1057,7 +1057,7 @@ def multi_decorated(): funcs = python_support.discover_functions(py_file) # Should find all functions regardless of decorators - names = {f.name for f in funcs} + names = {f.function_name for f in funcs} assert "decorated" in names assert "decorated_with_args" in names assert "multi_decorated" in names @@ -1077,7 +1077,7 @@ def test_function_expressions_js(self, js_support): funcs = js_support.discover_functions(js_file) # Should find function expressions - names = {f.name for f in funcs} + names = {f.function_name for f in funcs} assert "add" in names, "Should find anonymous function expression assigned to 'add'" @@ -1144,5 +1144,5 @@ def greeting(): assert len(py_funcs) == 1 assert len(js_funcs) == 1 - assert py_funcs[0].name == "greeting" - assert js_funcs[0].name == "greeting" + assert py_funcs[0].function_name == "greeting" + assert js_funcs[0].function_name == "greeting" diff --git a/tests/test_languages/test_python_support.py b/tests/test_languages/test_python_support.py index ea8c1a0de..c7be580ac 100644 --- a/tests/test_languages/test_python_support.py +++ b/tests/test_languages/test_python_support.py @@ -52,7 +52,7 @@ def add(a, b): functions = python_support.discover_functions(Path(f.name)) assert len(functions) == 1 - assert functions[0].name == "add" + assert functions[0].function_name == "add" assert functions[0].language == Language.PYTHON def test_discover_multiple_functions(self, python_support): @@ -73,7 +73,7 @@ def multiply(a, b): functions = python_support.discover_functions(Path(f.name)) assert len(functions) == 3 - names = {func.name for func in functions} + names = {func.function_name for func in functions} assert names == {"add", "subtract", "multiply"} def test_discover_function_with_no_return_excluded(self, python_support): @@ -92,7 +92,7 @@ def without_return(): # Only the function with return should be discovered assert len(functions) == 1 - assert functions[0].name == "with_return" + assert functions[0].function_name == "with_return" def test_discover_class_methods(self, python_support): """Test discovering class methods.""" @@ -130,8 +130,8 @@ def sync_function(): assert len(functions) == 2 - async_func = next(f for f in functions if f.name == "fetch_data") - sync_func = next(f for f in functions if f.name == "sync_function") + async_func = next(f for f in functions if f.function_name == "fetch_data") + sync_func = next(f for f in functions if f.function_name == "sync_function") assert async_func.is_async is True assert sync_func.is_async is False @@ -151,11 +151,11 @@ def inner(): # Both outer and inner should be discovered assert len(functions) == 2 - names = {func.name for func in functions} + names = {func.function_name for func in functions} assert names == {"outer", "inner"} # Inner should have outer as parent - inner = next(f for f in functions if f.name == "inner") + inner = next(f for f in functions if f.function_name == "inner") assert len(inner.parents) == 1 assert inner.parents[0].name == "outer" assert inner.parents[0].type == "FunctionDef" @@ -174,7 +174,7 @@ def helper(x): functions = python_support.discover_functions(Path(f.name)) assert len(functions) == 1 - assert functions[0].name == "helper" + assert functions[0].function_name == "helper" assert functions[0].class_name == "Utils" def test_discover_with_filter_exclude_async(self, python_support): @@ -193,7 +193,7 @@ def sync_func(): functions = python_support.discover_functions(Path(f.name), criteria) assert len(functions) == 1 - assert functions[0].name == "sync_func" + assert functions[0].function_name == "sync_func" def test_discover_with_filter_exclude_methods(self, python_support): """Test filtering out class methods.""" @@ -212,7 +212,7 @@ def method(self): functions = python_support.discover_functions(Path(f.name), criteria) assert len(functions) == 1 - assert functions[0].name == "standalone" + assert functions[0].function_name == "standalone" def test_discover_line_numbers(self, python_support): """Test that line numbers are correctly captured.""" @@ -229,13 +229,13 @@ def func2(): functions = python_support.discover_functions(Path(f.name)) - func1 = next(f for f in functions if f.name == "func1") - func2 = next(f for f in functions if f.name == "func2") + func1 = next(f for f in functions if f.function_name == "func1") + func2 = next(f for f in functions if f.function_name == "func2") - assert func1.start_line == 1 - assert func1.end_line == 2 - assert func2.start_line == 4 - assert func2.end_line == 7 + assert func1.starting_line == 1 + assert func1.ending_line == 2 + assert func2.starting_line == 4 + assert func2.ending_line == 7 def test_discover_invalid_file_returns_empty(self, python_support): """Test that invalid Python file returns empty list.""" @@ -263,7 +263,7 @@ def test_replace_simple_function(self, python_support): def multiply(a, b): return a * b """ - func = FunctionInfo(name="add", file_path=Path("/test.py"), start_line=1, end_line=2) + func = FunctionInfo(function_name="add", file_path=Path("/test.py"), starting_line=1, ending_line=2) new_code = """def add(a, b): # Optimized return (a + b) | 0 @@ -287,7 +287,7 @@ def other(): # Footer """ - func = FunctionInfo(name="target", file_path=Path("/test.py"), start_line=4, end_line=5) + func = FunctionInfo(function_name="target", file_path=Path("/test.py"), starting_line=4, ending_line=5) new_code = """def target(): return 42 """ @@ -306,11 +306,11 @@ def add(self, a, b): return a + b """ func = FunctionInfo( - name="add", + function_name="add", file_path=Path("/test.py"), - start_line=2, - end_line=3, - parents=(ParentInfo(name="Calculator", type="ClassDef"),), + starting_line=2, + ending_line=3, + parents=[ParentInfo(name="Calculator", type="ClassDef")], ) # New code has no indentation new_code = """def add(self, a, b): @@ -331,7 +331,7 @@ def test_replace_first_function(self, python_support): def second(): return 2 """ - func = FunctionInfo(name="first", file_path=Path("/test.py"), start_line=1, end_line=2) + func = FunctionInfo(function_name="first", file_path=Path("/test.py"), starting_line=1, ending_line=2) new_code = """def first(): return 100 """ @@ -348,7 +348,7 @@ def test_replace_last_function(self, python_support): def last(): return 999 """ - func = FunctionInfo(name="last", file_path=Path("/test.py"), start_line=4, end_line=5) + func = FunctionInfo(function_name="last", file_path=Path("/test.py"), starting_line=4, ending_line=5) new_code = """def last(): return 1000 """ @@ -362,7 +362,7 @@ def test_replace_only_function(self, python_support): source = """def only(): return 42 """ - func = FunctionInfo(name="only", file_path=Path("/test.py"), start_line=1, end_line=2) + func = FunctionInfo(function_name="only", file_path=Path("/test.py"), starting_line=1, ending_line=2) new_code = """def only(): return 100 """ @@ -474,7 +474,7 @@ def test_extract_simple_function(self, python_support): f.flush() file_path = Path(f.name) - func = FunctionInfo(name="add", file_path=file_path, start_line=1, end_line=2) + func = FunctionInfo(function_name="add", file_path=file_path, starting_line=1, ending_line=2) context = python_support.extract_code_context(func, file_path.parent, file_path.parent) @@ -503,7 +503,7 @@ def test_discover_and_replace_workflow(self, python_support): functions = python_support.discover_functions(file_path) assert len(functions) == 1 func = functions[0] - assert func.name == "fibonacci" + assert func.function_name == "fibonacci" # Replace optimized_code = """def fibonacci(n): diff --git a/tests/test_languages/test_treesitter_utils.py b/tests/test_languages/test_treesitter_utils.py index e5e776a11..a557a84dc 100644 --- a/tests/test_languages/test_treesitter_utils.py +++ b/tests/test_languages/test_treesitter_utils.py @@ -136,7 +136,7 @@ def test_find_function_declaration(self, js_analyzer): functions = js_analyzer.find_functions(code) assert len(functions) == 1 - assert functions[0].name == "add" + assert functions[0].function_name == "add" assert functions[0].is_arrow is False assert functions[0].is_async is False assert functions[0].is_method is False @@ -151,7 +151,7 @@ def test_find_arrow_function(self, js_analyzer): functions = js_analyzer.find_functions(code) assert len(functions) == 1 - assert functions[0].name == "add" + assert functions[0].function_name == "add" assert functions[0].is_arrow is True def test_find_arrow_function_concise(self, js_analyzer): @@ -160,7 +160,7 @@ def test_find_arrow_function_concise(self, js_analyzer): functions = js_analyzer.find_functions(code) assert len(functions) == 1 - assert functions[0].name == "double" + assert functions[0].function_name == "double" assert functions[0].is_arrow is True def test_find_async_function(self, js_analyzer): @@ -173,7 +173,7 @@ def test_find_async_function(self, js_analyzer): functions = js_analyzer.find_functions(code) assert len(functions) == 1 - assert functions[0].name == "fetchData" + assert functions[0].function_name == "fetchData" assert functions[0].is_async is True def test_find_class_methods(self, js_analyzer): @@ -188,7 +188,7 @@ class Calculator { functions = js_analyzer.find_functions(code, include_methods=True) assert len(functions) == 1 - assert functions[0].name == "add" + assert functions[0].function_name == "add" assert functions[0].is_method is True assert functions[0].class_name == "Calculator" @@ -208,7 +208,7 @@ class Calculator { functions = js_analyzer.find_functions(code, include_methods=False) assert len(functions) == 1 - assert functions[0].name == "standalone" + assert functions[0].function_name == "standalone" def test_exclude_arrow_functions(self, js_analyzer): """Test excluding arrow functions.""" @@ -222,7 +222,7 @@ def test_exclude_arrow_functions(self, js_analyzer): functions = js_analyzer.find_functions(code, include_arrow_functions=False) assert len(functions) == 1 - assert functions[0].name == "regular" + assert functions[0].function_name == "regular" def test_find_generator_function(self, js_analyzer): """Test finding generator functions.""" @@ -235,7 +235,7 @@ def test_find_generator_function(self, js_analyzer): functions = js_analyzer.find_functions(code) assert len(functions) == 1 - assert functions[0].name == "numberGenerator" + assert functions[0].function_name == "numberGenerator" assert functions[0].is_generator is True def test_function_line_numbers(self, js_analyzer): @@ -291,7 +291,7 @@ def test_require_name_filters_anonymous(self, js_analyzer): functions = js_analyzer.find_functions(code, require_name=True) assert len(functions) == 1 - assert functions[0].name == "named" + assert functions[0].function_name == "named" def test_function_expression_in_variable(self, js_analyzer): """Test function expression assigned to variable.""" @@ -303,7 +303,7 @@ def test_function_expression_in_variable(self, js_analyzer): functions = js_analyzer.find_functions(code) assert len(functions) == 1 - assert functions[0].name == "add" + assert functions[0].function_name == "add" class TestFindImports: @@ -515,7 +515,7 @@ def test_find_typed_function(self, ts_analyzer): functions = ts_analyzer.find_functions(code) assert len(functions) == 1 - assert functions[0].name == "add" + assert functions[0].function_name == "add" def test_find_interface_method(self, ts_analyzer): """Test that interface methods are not found (they're declarations).""" @@ -544,4 +544,4 @@ def test_find_generic_function(self, ts_analyzer): functions = ts_analyzer.find_functions(code) assert len(functions) == 1 - assert functions[0].name == "identity" + assert functions[0].function_name == "identity" diff --git a/tests/test_languages/test_typescript_code_extraction.py b/tests/test_languages/test_typescript_code_extraction.py index 32bf77215..f97049943 100644 --- a/tests/test_languages/test_typescript_code_extraction.py +++ b/tests/test_languages/test_typescript_code_extraction.py @@ -128,7 +128,7 @@ def test_extract_simple_function(self, ts_support): functions = ts_support.discover_functions(file_path) assert len(functions) == 1 - assert functions[0].name == "add" + assert functions[0].function_name == "add" # Extract code context code_context = ts_support.extract_code_context( @@ -166,7 +166,7 @@ def test_extract_async_function_with_template_literal(self, ts_support): functions = ts_support.discover_functions(file_path) assert len(functions) == 1 - assert functions[0].name == "execMongoEval" + assert functions[0].function_name == "execMongoEval" # Extract code context code_context = ts_support.extract_code_context( @@ -217,7 +217,7 @@ def test_extract_function_with_complex_try_catch(self, ts_support): functions = ts_support.discover_functions(file_path) assert len(functions) == 1 - assert functions[0].name == "figureOutContentsPath" + assert functions[0].function_name == "figureOutContentsPath" # Extract code context code_context = ts_support.extract_code_context( diff --git a/tests/test_languages/test_vitest_e2e.py b/tests/test_languages/test_vitest_e2e.py index b29b90f73..68448c1cf 100644 --- a/tests/test_languages/test_vitest_e2e.py +++ b/tests/test_languages/test_vitest_e2e.py @@ -177,7 +177,7 @@ def test_discover_vitest_tests(self, vitest_project_dir): fib_file = vitest_project_dir / "fibonacci.ts" func_info = FunctionInfo( - name="fibonacci", file_path=fib_file, start_line=11, end_line=16, language=Language.TYPESCRIPT + function_name="fibonacci", file_path=fib_file, starting_line=11, ending_line=16, language="typescript" ) tests = ts_support.discover_tests(test_root, [func_info]) From a5edb73b133eb1af3ca5058ea945741307c86e73 Mon Sep 17 00:00:00 2001 From: misrasaurabh1 Date: Mon, 2 Feb 2026 11:32:39 -0800 Subject: [PATCH 3/5] fix: Use FunctionToOptimize field names consistently across JS code MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Fix field name mismatches: .name → .function_name, .start_line → .starting_line, .end_line → .ending_line, .start_col → .starting_col, .end_col → .ending_col - Fix circular imports by creating function_types.py with FunctionParent - Add lazy language registration via _ensure_languages_registered() - Fix macOS symlink path resolution in ImportResolver - Update all affected code and tests to use correct FunctionToOptimize attributes Co-Authored-By: Claude Opus 4.5 --- codeflash/cli_cmds/init_javascript.py | 15 ++-- codeflash/code_utils/code_replacer.py | 14 ++-- codeflash/discovery/functions_to_optimize.py | 21 +---- codeflash/languages/__init__.py | 40 ++++++---- codeflash/languages/base.py | 5 +- .../languages/javascript/find_references.py | 11 +-- .../languages/javascript/import_resolver.py | 7 +- codeflash/languages/javascript/instrument.py | 3 +- .../languages/javascript/line_profiler.py | 12 +-- codeflash/languages/javascript/support.py | 2 +- codeflash/languages/javascript/test_runner.py | 17 +--- codeflash/languages/javascript/tracer.py | 10 +-- .../languages/javascript/vitest_runner.py | 9 +-- codeflash/languages/python/support.py | 12 +-- codeflash/languages/registry.py | 38 +++++++++ codeflash/models/function_types.py | 18 +++++ codeflash/models/models.py | 9 +-- .../test_code_context_extraction.py | 38 ++++----- .../test_javascript_test_discovery.py | 2 +- .../test_languages/test_js_code_extractor.py | 79 ++++++++++--------- tests/test_languages/test_js_code_replacer.py | 24 +++--- .../test_multi_file_code_replacer.py | 13 +-- tests/test_languages/test_treesitter_utils.py | 24 +++--- 23 files changed, 224 insertions(+), 199 deletions(-) create mode 100644 codeflash/models/function_types.py diff --git a/codeflash/cli_cmds/init_javascript.py b/codeflash/cli_cmds/init_javascript.py index 05159f65e..3e64bccae 100644 --- a/codeflash/cli_cmds/init_javascript.py +++ b/codeflash/cli_cmds/init_javascript.py @@ -155,22 +155,21 @@ def get_package_install_command(project_root: Path, package: str, dev: bool = Tr if dev: cmd.append("--save-dev") return cmd - elif pkg_manager == JsPackageManager.YARN: + if pkg_manager == JsPackageManager.YARN: cmd = ["yarn", "add", package] if dev: cmd.append("--dev") return cmd - elif pkg_manager == JsPackageManager.BUN: + if pkg_manager == JsPackageManager.BUN: cmd = ["bun", "add", package] if dev: cmd.append("--dev") return cmd - else: - # Default to npm - cmd = ["npm", "install", package] - if dev: - cmd.append("--save-dev") - return cmd + # Default to npm + cmd = ["npm", "install", package] + if dev: + cmd.append("--save-dev") + return cmd def init_js_project(language: ProjectLanguage) -> None: diff --git a/codeflash/code_utils/code_replacer.py b/codeflash/code_utils/code_replacer.py index b979dc37e..942b912e8 100644 --- a/codeflash/code_utils/code_replacer.py +++ b/codeflash/code_utils/code_replacer.py @@ -549,7 +549,7 @@ def replace_function_definitions_for_language( # Find the function in current code func = None for f in current_functions: - if func_name in (f.qualified_name, f.name): + if func_name in (f.qualified_name, f.function_name): func = f break @@ -557,7 +557,9 @@ def replace_function_definitions_for_language( continue # Extract just this function from the optimized code - optimized_func = _extract_function_from_code(lang_support, code_to_apply, func.name, module_abspath) + optimized_func = _extract_function_from_code( + lang_support, code_to_apply, func.function_name, module_abspath + ) if optimized_func: new_code = lang_support.replace_function(new_code, func, optimized_func) modified = True @@ -596,13 +598,13 @@ def _extract_function_from_code( # file_path is needed for JS/TS to determine correct analyzer (TypeScript vs JavaScript) functions = lang_support.discover_functions_from_source(source_code, file_path) for func in functions: - if func.name == function_name: + if func.function_name == function_name: # Extract the function's source using line numbers # Use doc_start_line if available to include JSDoc/docstring lines = source_code.splitlines(keepends=True) - effective_start = func.doc_start_line or func.start_line - if effective_start and func.end_line and effective_start <= len(lines): - func_lines = lines[effective_start - 1 : func.end_line] + effective_start = func.doc_start_line or func.starting_line + if effective_start and func.ending_line and effective_start <= len(lines): + func_lines = lines[effective_start - 1 : func.ending_line] return "".join(func_lines) except Exception as e: logger.debug(f"Error extracting function {function_name}: {e}") diff --git a/codeflash/discovery/functions_to_optimize.py b/codeflash/discovery/functions_to_optimize.py index 6db677ca7..b7d6b5d4f 100644 --- a/codeflash/discovery/functions_to_optimize.py +++ b/codeflash/discovery/functions_to_optimize.py @@ -335,25 +335,8 @@ def _find_all_functions_via_language_support(file_path: Path) -> dict[Path, list try: lang_support = get_language_support(file_path) criteria = FunctionFilterCriteria(require_return=True) - function_infos = lang_support.discover_functions(file_path, criteria) - - ftos = [] - for func_info in function_infos: - parents = [FunctionParent(p.name, p.type) for p in func_info.parents] - ftos.append( - FunctionToOptimize( - function_name=func_info.name, - file_path=func_info.file_path, - parents=parents, - starting_line=func_info.start_line, - ending_line=func_info.end_line, - starting_col=func_info.start_col, - ending_col=func_info.end_col, - is_async=func_info.is_async, - language=func_info.language.value, - ) - ) - functions[file_path] = ftos + # discover_functions already returns FunctionToOptimize objects + functions[file_path] = lang_support.discover_functions(file_path, criteria) except Exception as e: logger.debug(f"Failed to discover functions in {file_path}: {e}") diff --git a/codeflash/languages/__init__.py b/codeflash/languages/__init__.py index 0c4e9d07c..47136f4e7 100644 --- a/codeflash/languages/__init__.py +++ b/codeflash/languages/__init__.py @@ -26,17 +26,6 @@ TestInfo, TestResult, ) - - -# Lazy import for FunctionInfo to avoid circular imports -def __getattr__(name: str): - if name == "FunctionInfo": - from codeflash.discovery.functions_to_optimize import FunctionToOptimize - - return FunctionToOptimize - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") - - from codeflash.languages.current import ( current_language, current_language_support, @@ -46,11 +35,9 @@ def __getattr__(name: str): reset_current_language, set_current_language, ) -from codeflash.languages.javascript import JavaScriptSupport, TypeScriptSupport # noqa: F401 -# Import language support modules to trigger auto-registration -# This ensures all supported languages are available when this package is imported -from codeflash.languages.python import PythonSupport # noqa: F401 +# Language support modules are imported lazily to avoid circular imports +# They get registered when first accessed via get_language_support() from codeflash.languages.registry import ( detect_project_language, get_language_support, @@ -70,6 +57,29 @@ def __getattr__(name: str): set_current_test_framework, ) + +# Lazy imports to avoid circular imports +def __getattr__(name: str): + if name == "FunctionInfo": + from codeflash.discovery.functions_to_optimize import FunctionToOptimize + + return FunctionToOptimize + if name == "JavaScriptSupport": + from codeflash.languages.javascript.support import JavaScriptSupport + + return JavaScriptSupport + if name == "TypeScriptSupport": + from codeflash.languages.javascript.support import TypeScriptSupport + + return TypeScriptSupport + if name == "PythonSupport": + from codeflash.languages.python.support import PythonSupport + + return PythonSupport + msg = f"module {__name__!r} has no attribute {name!r}" + raise AttributeError(msg) + + __all__ = [ "CodeContext", "FunctionInfo", diff --git a/codeflash/languages/base.py b/codeflash/languages/base.py index 5fb7f99ce..99cefdf46 100644 --- a/codeflash/languages/base.py +++ b/codeflash/languages/base.py @@ -17,7 +17,7 @@ from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.languages.language_enum import Language -from codeflash.models.models import FunctionParent +from codeflash.models.function_types import FunctionParent # Backward compatibility aliases - ParentInfo is now FunctionParent ParentInfo = FunctionParent @@ -30,7 +30,8 @@ def __getattr__(name: str) -> Any: from codeflash.discovery.functions_to_optimize import FunctionToOptimize return FunctionToOptimize - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + msg = f"module {__name__!r} has no attribute {name!r}" + raise AttributeError(msg) @dataclass diff --git a/codeflash/languages/javascript/find_references.py b/codeflash/languages/javascript/find_references.py index 3cf0d8b30..812f7c4a7 100644 --- a/codeflash/languages/javascript/find_references.py +++ b/codeflash/languages/javascript/find_references.py @@ -15,16 +15,16 @@ import logging from dataclasses import dataclass, field -from pathlib import Path from typing import TYPE_CHECKING if TYPE_CHECKING: + from pathlib import Path + from tree_sitter import Node + from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.languages.treesitter_utils import ImportInfo, TreeSitterAnalyzer -from codeflash.discovery.functions_to_optimize import FunctionToOptimize - logger = logging.getLogger(__name__) @@ -781,10 +781,7 @@ def _should_exclude(self, file_path: Path) -> bool: """ path_str = str(file_path) - for pattern in self.exclude_patterns: - if pattern in path_str: - return True - return False + return any(pattern in path_str for pattern in self.exclude_patterns) def _read_file(self, file_path: Path) -> str | None: """Read a file's contents with caching. diff --git a/codeflash/languages/javascript/import_resolver.py b/codeflash/languages/javascript/import_resolver.py index f4def166f..6d65432f1 100644 --- a/codeflash/languages/javascript/import_resolver.py +++ b/codeflash/languages/javascript/import_resolver.py @@ -44,7 +44,8 @@ def __init__(self, project_root: Path) -> None: project_root: Root directory of the project. """ - self.project_root = project_root + # Resolve to real path to handle macOS symlinks like /var -> /private/var + self.project_root = project_root.resolve() self._resolution_cache: dict[tuple[Path, str], Path | None] = {} def resolve_import(self, import_info: ImportInfo, source_file: Path) -> ResolvedImport | None: @@ -329,7 +330,7 @@ def find_helpers( all_functions = analyzer.find_functions(source, include_methods=True) target_func = None for func in all_functions: - if func.name == function.name and func.start_line == function.start_line: + if func.name == function.function_name and func.start_line == function.starting_line: target_func = func break @@ -506,7 +507,7 @@ def _find_helpers_recursive( Dictionary mapping file paths to lists of helper functions. """ - from codeflash.languages.base import FunctionToOptimize + from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.languages.treesitter_utils import get_analyzer_for_file if context.current_depth >= context.max_depth: diff --git a/codeflash/languages/javascript/instrument.py b/codeflash/languages/javascript/instrument.py index d8ddad489..30e7fff7a 100644 --- a/codeflash/languages/javascript/instrument.py +++ b/codeflash/languages/javascript/instrument.py @@ -15,8 +15,7 @@ if TYPE_CHECKING: from codeflash.code_utils.code_position import CodePosition - -from codeflash.discovery.functions_to_optimize import FunctionToOptimize + from codeflash.discovery.functions_to_optimize import FunctionToOptimize class TestingMode: diff --git a/codeflash/languages/javascript/line_profiler.py b/codeflash/languages/javascript/line_profiler.py index c2ea2f495..57f046d4a 100644 --- a/codeflash/languages/javascript/line_profiler.py +++ b/codeflash/languages/javascript/line_profiler.py @@ -65,10 +65,10 @@ def instrument_source(self, source: str, file_path: Path, functions: list[Functi lines = source.splitlines(keepends=True) # Process functions in reverse order to preserve line numbers - for func in sorted(functions, key=lambda f: f.start_line, reverse=True): + for func in sorted(functions, key=lambda f: f.starting_line, reverse=True): func_lines = self._instrument_function(func, lines, file_path) - start_idx = func.start_line - 1 - end_idx = func.end_line + start_idx = func.starting_line - 1 + end_idx = func.ending_line lines = lines[:start_idx] + func_lines + lines[end_idx:] instrumented_source = "".join(lines) @@ -183,7 +183,7 @@ def _instrument_function(self, func: FunctionToOptimize, lines: list[str], file_ Instrumented function lines. """ - func_lines = lines[func.start_line - 1 : func.end_line] + func_lines = lines[func.starting_line - 1 : func.ending_line] instrumented_lines = [] # Parse the function to find executable lines @@ -194,7 +194,7 @@ def _instrument_function(self, func: FunctionToOptimize, lines: list[str], file_ tree = analyzer.parse(source.encode("utf8")) executable_lines = self._find_executable_lines(tree.root_node, source.encode("utf8")) except Exception as e: - logger.warning("Failed to parse function %s: %s", func.name, e) + logger.warning("Failed to parse function %s: %s", func.function_name, e) return func_lines # Add profiling to each executable line @@ -203,7 +203,7 @@ def _instrument_function(self, func: FunctionToOptimize, lines: list[str], file_ for local_idx, line in enumerate(func_lines): local_line_num = local_idx + 1 # 1-indexed within function - global_line_num = func.start_line + local_idx # Global line number in original file + global_line_num = func.starting_line + local_idx # Global line number in original file stripped = line.strip() # Add enterFunction() call after the opening brace of the function diff --git a/codeflash/languages/javascript/support.py b/codeflash/languages/javascript/support.py index d9ca2cfea..3b3323447 100644 --- a/codeflash/languages/javascript/support.py +++ b/codeflash/languages/javascript/support.py @@ -241,7 +241,7 @@ def discover_tests( # Match source functions to tests for func in source_functions: - if func.name in imported_names or func.name in source: + if func.function_name in imported_names or func.function_name in source: if func.qualified_name not in result: result[func.qualified_name] = [] for test_name in test_functions: diff --git a/codeflash/languages/javascript/test_runner.py b/codeflash/languages/javascript/test_runner.py index f9eb5f29f..3b9148819 100644 --- a/codeflash/languages/javascript/test_runner.py +++ b/codeflash/languages/javascript/test_runner.py @@ -63,13 +63,7 @@ def _find_jest_config(project_root: Path) -> Path | None: """ # Common Jest config file names, in order of preference - config_names = [ - "jest.config.ts", - "jest.config.js", - "jest.config.mjs", - "jest.config.cjs", - "jest.config.json", - ] + config_names = ["jest.config.ts", "jest.config.js", "jest.config.mjs", "jest.config.cjs", "jest.config.json"] # First check the project root itself for config_name in config_names: @@ -226,14 +220,7 @@ def _ensure_runtime_files(project_root: Path) -> None: install_cmd = get_package_install_command(project_root, "codeflash", dev=True) try: - result = subprocess.run( - install_cmd, - check=False, - cwd=project_root, - capture_output=True, - text=True, - timeout=120, - ) + result = subprocess.run(install_cmd, check=False, cwd=project_root, capture_output=True, text=True, timeout=120) if result.returncode == 0: logger.debug(f"Installed codeflash using {install_cmd[0]}") return diff --git a/codeflash/languages/javascript/tracer.py b/codeflash/languages/javascript/tracer.py index 7632db50b..2f5791ee0 100644 --- a/codeflash/languages/javascript/tracer.py +++ b/codeflash/languages/javascript/tracer.py @@ -64,10 +64,10 @@ def instrument_source(self, source: str, file_path: Path, functions: list[Functi lines = source.splitlines(keepends=True) # Process functions in reverse order to preserve line numbers - for func in sorted(functions, key=lambda f: f.start_line, reverse=True): + for func in sorted(functions, key=lambda f: f.starting_line, reverse=True): instrumented = self._instrument_function(func, lines, file_path) - start_idx = func.start_line - 1 - end_idx = func.end_line + start_idx = func.starting_line - 1 + end_idx = func.ending_line lines = lines[:start_idx] + instrumented + lines[end_idx:] instrumented_source = "".join(lines) @@ -281,11 +281,11 @@ def _instrument_function(self, func: FunctionToOptimize, lines: list[str], file_ Instrumented function lines. """ - func_lines = lines[func.start_line - 1 : func.end_line] + func_lines = lines[func.starting_line - 1 : func.ending_line] func_text = "".join(func_lines) # Detect function pattern - func_name = func.name + func_name = func.function_name is_arrow = "=>" in func_text.split("\n")[0] is_method = func.is_method is_async = func.is_async diff --git a/codeflash/languages/javascript/vitest_runner.py b/codeflash/languages/javascript/vitest_runner.py index 644128aa6..47a529dae 100644 --- a/codeflash/languages/javascript/vitest_runner.py +++ b/codeflash/languages/javascript/vitest_runner.py @@ -86,14 +86,7 @@ def _ensure_runtime_files(project_root: Path) -> None: install_cmd = get_package_install_command(project_root, "codeflash", dev=True) try: - result = subprocess.run( - install_cmd, - check=False, - cwd=project_root, - capture_output=True, - text=True, - timeout=120, - ) + result = subprocess.run(install_cmd, check=False, cwd=project_root, capture_output=True, text=True, timeout=120) if result.returncode == 0: logger.debug(f"Installed codeflash using {install_cmd[0]}") return diff --git a/codeflash/languages/python/support.py b/codeflash/languages/python/support.py index 8268ca8d3..ebaf47b19 100644 --- a/codeflash/languages/python/support.py +++ b/codeflash/languages/python/support.py @@ -159,7 +159,7 @@ def discover_tests( try: source = test_file.read_text() # Check if function name appears in test file - if func.name in source: + if func.function_name in source: result[func.qualified_name].append( TestInfo(test_name=test_file.stem, test_file=test_file, test_class=None) ) @@ -289,7 +289,7 @@ def find_helper_functions(self, function: FunctionToOptimize, project_root: Path ) except Exception as e: - logger.warning("Failed to find helpers for %s: %s", function.name, e) + logger.warning("Failed to find helpers for %s: %s", function.function_name, e) return helpers @@ -389,10 +389,10 @@ def find_references( line=ref.line, column=ref.column, end_line=ref.line, - end_column=ref.column + len(function.name), + end_column=ref.column + len(function.function_name), context=context.strip(), reference_type="call", - import_name=function.name, + import_name=function.function_name, caller_function=caller_function, ) ) @@ -400,7 +400,7 @@ def find_references( return result except Exception as e: - logger.warning("Failed to find references for %s: %s", function.name, e) + logger.warning("Failed to find references for %s: %s", function.function_name, e) return [] # === Code Transformation === @@ -433,7 +433,7 @@ def replace_function(self, source: str, function: FunctionToOptimize, new_source preexisting_objects=set(), ) except Exception as e: - logger.warning("Failed to replace function %s: %s", function.name, e) + logger.warning("Failed to replace function %s: %s", function.function_name, e) return source def format_code(self, source: str, file_path: Path | None = None) -> str: diff --git a/codeflash/languages/registry.py b/codeflash/languages/registry.py index 426e5ada2..e7b971fbe 100644 --- a/codeflash/languages/registry.py +++ b/codeflash/languages/registry.py @@ -30,6 +30,33 @@ # Cache of instantiated language support objects _SUPPORT_CACHE: dict[Language, LanguageSupport] = {} +# Flag to track if language modules have been imported +_languages_registered = False + + +def _ensure_languages_registered() -> None: + """Ensure all language support modules are imported and registered. + + This lazily imports the language support modules to avoid circular imports + at module load time. The imports trigger the @register_language decorators + which populate the registries. + """ + global _languages_registered + if _languages_registered: + return + + # Import support modules to trigger registration + # These imports are deferred to avoid circular imports + import contextlib + + with contextlib.suppress(ImportError): + from codeflash.languages.python import support as _ + + with contextlib.suppress(ImportError): + from codeflash.languages.javascript import support as _ # noqa: F401 + + _languages_registered = True + class UnsupportedLanguageError(Exception): """Raised when attempting to use an unsupported language.""" @@ -123,6 +150,10 @@ def get_language_support(identifier: Path | Language | str) -> LanguageSupport: Raises: UnsupportedLanguageError: If the language is not supported. + Note: + This function lazily imports language support modules on first call + to avoid circular import issues at module load time. + Example: # By file path lang = get_language_support(Path("example.py")) @@ -137,6 +168,7 @@ def get_language_support(identifier: Path | Language | str) -> LanguageSupport: lang = get_language_support("python") """ + _ensure_languages_registered() language: Language | None = None if isinstance(identifier, Language): @@ -179,6 +211,7 @@ def get_language_support(identifier: Path | Language | str) -> LanguageSupport: def get_language_support_by_common_formatters(formatter_cmd: str | list[str]) -> LanguageSupport | None: + _ensure_languages_registered() language: Language | None = None if isinstance(formatter_cmd, str): formatter_cmd = [formatter_cmd] @@ -263,6 +296,7 @@ def detect_project_language(project_root: Path, module_root: Path) -> Language: UnsupportedLanguageError: If no supported language is detected. """ + _ensure_languages_registered() extension_counts: dict[str, int] = {} # Count files by extension @@ -290,6 +324,7 @@ def get_supported_languages() -> list[str]: List of language name strings. """ + _ensure_languages_registered() return [lang.value for lang in _LANGUAGE_REGISTRY] @@ -300,6 +335,7 @@ def get_supported_extensions() -> list[str]: List of extension strings (with leading dots). """ + _ensure_languages_registered() return list(_EXTENSION_REGISTRY.keys()) @@ -325,10 +361,12 @@ def clear_registry() -> None: Primarily useful for testing. """ + global _languages_registered _EXTENSION_REGISTRY.clear() _LANGUAGE_REGISTRY.clear() _SUPPORT_CACHE.clear() _FRAMEWORK_CACHE.clear() + _languages_registered = False def clear_cache() -> None: diff --git a/codeflash/models/function_types.py b/codeflash/models/function_types.py new file mode 100644 index 000000000..9ff1036ae --- /dev/null +++ b/codeflash/models/function_types.py @@ -0,0 +1,18 @@ +"""Simple function-related types with no dependencies. + +This module contains basic types used for function representation. +It is intentionally kept dependency-free to avoid circular imports. +""" + +from __future__ import annotations + +from pydantic.dataclasses import dataclass + + +@dataclass(frozen=True) +class FunctionParent: + name: str + type: str + + def __str__(self) -> str: + return f"{self.type}:{self.name}" diff --git a/codeflash/models/models.py b/codeflash/models/models.py index 1e8509c66..5a5b0c5b5 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -598,13 +598,8 @@ class CodePosition: col_no: int -@dataclass(frozen=True) -class FunctionParent: - name: str - type: str - - def __str__(self) -> str: - return f"{self.type}:{self.name}" +# Re-export FunctionParent for backward compatibility +from codeflash.models.function_types import FunctionParent # noqa: E402 class OriginalCodeBaseline(BaseModel): diff --git a/tests/test_languages/test_code_context_extraction.py b/tests/test_languages/test_code_context_extraction.py index 21a0c26db..7ac0920cf 100644 --- a/tests/test_languages/test_code_context_extraction.py +++ b/tests/test_languages/test_code_context_extraction.py @@ -90,7 +90,7 @@ def test_arrow_function_with_implicit_return(self, js_support, temp_project): functions = js_support.discover_functions(file_path) assert len(functions) == 1 func = functions[0] - assert func.name == "multiply" + assert func.function_name == "multiply" context = js_support.extract_code_context(func, temp_project, temp_project) @@ -268,7 +268,7 @@ class CacheManager { file_path.write_text(code, encoding="utf-8") functions = js_support.discover_functions(file_path) - get_or_compute = next(f for f in functions if f.name == "getOrCompute") + get_or_compute = next(f for f in functions if f.function_name == "getOrCompute") context = js_support.extract_code_context(get_or_compute, temp_project, temp_project) @@ -370,7 +370,7 @@ def test_jsdoc_with_typedef_and_callback(self, js_support, temp_project): file_path.write_text(code, encoding="utf-8") functions = js_support.discover_functions(file_path) - func = next(f for f in functions if f.name == "validateUserData") + func = next(f for f in functions if f.function_name == "validateUserData") context = js_support.extract_code_context(func, temp_project, temp_project) @@ -466,7 +466,7 @@ def test_function_with_multiple_complex_constants(self, js_support, temp_project file_path.write_text(code, encoding="utf-8") functions = js_support.discover_functions(file_path) - func = next(f for f in functions if f.name == "fetchWithRetry") + func = next(f for f in functions if f.function_name == "fetchWithRetry") context = js_support.extract_code_context(func, temp_project, temp_project) @@ -615,7 +615,7 @@ def test_function_with_chain_of_helpers(self, js_support, temp_project): file_path.write_text(code, encoding="utf-8") functions = js_support.discover_functions(file_path) - process_func = next(f for f in functions if f.name == "processUserInput") + process_func = next(f for f in functions if f.function_name == "processUserInput") context = js_support.extract_code_context(process_func, temp_project, temp_project) @@ -670,7 +670,7 @@ def test_function_with_multiple_unrelated_helpers(self, js_support, temp_project file_path.write_text(code, encoding="utf-8") functions = js_support.discover_functions(file_path) - report_func = next(f for f in functions if f.name == "generateReport") + report_func = next(f for f in functions if f.function_name == "generateReport") context = js_support.extract_code_context(report_func, temp_project, temp_project) @@ -768,7 +768,7 @@ class Graph { file_path.write_text(code, encoding="utf-8") functions = js_support.discover_functions(file_path) - topo_sort = next(f for f in functions if f.name == "topologicalSort") + topo_sort = next(f for f in functions if f.function_name == "topologicalSort") context = js_support.extract_code_context(topo_sort, temp_project, temp_project) @@ -843,7 +843,7 @@ class MainClass { file_path.write_text(code, encoding="utf-8") functions = js_support.discover_functions(file_path) - main_method = next(f for f in functions if f.name == "mainMethod" and f.class_name == "MainClass") + main_method = next(f for f in functions if f.function_name == "mainMethod" and f.class_name == "MainClass") context = js_support.extract_code_context(main_method, temp_project, temp_project) @@ -899,7 +899,7 @@ def test_helper_from_another_file_commonjs(self, js_support, temp_project): main_path.write_text(main_code, encoding="utf-8") functions = js_support.discover_functions(main_path) - main_func = next(f for f in functions if f.name == "sortFromAnotherFile") + main_func = next(f for f in functions if f.function_name == "sortFromAnotherFile") context = js_support.extract_code_context(main_func, temp_project, temp_project) @@ -952,7 +952,7 @@ def test_helper_from_another_file_esm(self, js_support, temp_project): main_path.write_text(main_code, encoding="utf-8") functions = js_support.discover_functions(main_path) - process_func = next(f for f in functions if f.name == "processNumber") + process_func = next(f for f in functions if f.function_name == "processNumber") context = js_support.extract_code_context(process_func, temp_project, temp_project) @@ -1020,7 +1020,7 @@ def test_chained_imports_across_three_files(self, js_support, temp_project): main_path.write_text(main_code, encoding="utf-8") functions = js_support.discover_functions(main_path) - handle_func = next(f for f in functions if f.name == "handleUserInput") + handle_func = next(f for f in functions if f.function_name == "handleUserInput") context = js_support.extract_code_context(handle_func, temp_project, temp_project) @@ -1161,7 +1161,7 @@ class TypedCache { file_path.write_text(code, encoding="utf-8") functions = ts_support.discover_functions(file_path) - get_method = next(f for f in functions if f.name == "get") + get_method = next(f for f in functions if f.function_name == "get") context = ts_support.extract_code_context(get_method, temp_project, temp_project) @@ -1247,7 +1247,7 @@ def test_typescript_with_type_imports(self, ts_support, temp_project): service_path.write_text(service_code, encoding="utf-8") functions = ts_support.discover_functions(service_path) - func = next(f for f in functions if f.name == "createUser") + func = next(f for f in functions if f.function_name == "createUser") context = ts_support.extract_code_context(func, temp_project, temp_project) @@ -1331,7 +1331,7 @@ def test_mutually_recursive_even_odd(self, js_support, temp_project): file_path.write_text(code, encoding="utf-8") functions = js_support.discover_functions(file_path) - is_even = next(f for f in functions if f.name == "isEven") + is_even = next(f for f in functions if f.function_name == "isEven") context = js_support.extract_code_context(is_even, temp_project, temp_project) @@ -1393,7 +1393,7 @@ def test_complex_recursive_tree_traversal(self, js_support, temp_project): file_path.write_text(code, encoding="utf-8") functions = js_support.discover_functions(file_path) - collect_func = next(f for f in functions if f.name == "collectAllValues") + collect_func = next(f for f in functions if f.function_name == "collectAllValues") context = js_support.extract_code_context(collect_func, temp_project, temp_project) @@ -1458,7 +1458,7 @@ def test_async_function_chain(self, js_support, temp_project): file_path.write_text(code, encoding="utf-8") functions = js_support.discover_functions(file_path) - profile_func = next(f for f in functions if f.name == "fetchUserProfile") + profile_func = next(f for f in functions if f.function_name == "fetchUserProfile") context = js_support.extract_code_context(profile_func, temp_project, temp_project) @@ -1513,7 +1513,7 @@ class Counter { file_path.write_text(original_source, encoding="utf-8") functions = js_support.discover_functions(file_path) - increment_func = next(fn for fn in functions if fn.name == "increment") + increment_func = next(fn for fn in functions if fn.function_name == "increment") # Step 1: Extract code context context = js_support.extract_code_context(increment_func, temp_project, temp_project) @@ -1635,7 +1635,7 @@ def test_generator_function(self, js_support, temp_project): file_path.write_text(code, encoding="utf-8") functions = js_support.discover_functions(file_path) - range_func = next(f for f in functions if f.name == "range") + range_func = next(f for f in functions if f.function_name == "range") context = js_support.extract_code_context(range_func, temp_project, temp_project) @@ -1772,7 +1772,7 @@ class Calculator { functions = js_support.discover_functions(file_path) for func in functions: - if func.name != "constructor": + if func.function_name != "constructor": context = js_support.extract_code_context(func, temp_project, temp_project) is_valid = js_support.validate_syntax(context.target_code) assert is_valid is True, f"Invalid syntax for {func.name}:\n{context.target_code}" diff --git a/tests/test_languages/test_javascript_test_discovery.py b/tests/test_languages/test_javascript_test_discovery.py index 182535f7a..9166b589e 100644 --- a/tests/test_languages/test_javascript_test_discovery.py +++ b/tests/test_languages/test_javascript_test_discovery.py @@ -1715,7 +1715,7 @@ class Calculator { functions = js_support.discover_functions(source_file) # Check qualified names include class - add_func = next((f for f in functions if f.name == "add"), None) + add_func = next((f for f in functions if f.function_name == "add"), None) assert add_func is not None assert add_func.class_name == "Calculator" diff --git a/tests/test_languages/test_js_code_extractor.py b/tests/test_languages/test_js_code_extractor.py index a4b2e9e8f..b1dcee81f 100644 --- a/tests/test_languages/test_js_code_extractor.py +++ b/tests/test_languages/test_js_code_extractor.py @@ -39,7 +39,7 @@ def test_discover_class_methods(self, js_support, cjs_project): calculator_file = cjs_project / "calculator.js" functions = js_support.discover_functions(calculator_file) - method_names = {f.name for f in functions} + method_names = {f.function_name for f in functions} expected_methods = {"calculateCompoundInterest", "permutation", "quickAdd"} assert method_names == expected_methods, f"Expected methods {expected_methods}, got {method_names}" @@ -51,15 +51,15 @@ def test_class_method_has_correct_parent(self, js_support, cjs_project): for func in functions: # All methods should belong to Calculator class - assert func.is_method is True, f"{func.name} should be a method" - assert func.class_name == "Calculator", f"{func.name} should belong to Calculator, got {func.class_name}" + assert func.is_method is True, f"{func.function_name} should be a method" + assert func.class_name == "Calculator", f"{func.function_name} should belong to Calculator, got {func.class_name}" def test_extract_permutation_code(self, js_support, cjs_project): """Test permutation method code extraction.""" calculator_file = cjs_project / "calculator.js" functions = js_support.discover_functions(calculator_file) - permutation_func = next(f for f in functions if f.name == "permutation") + permutation_func = next(f for f in functions if f.function_name == "permutation") context = js_support.extract_code_context( function=permutation_func, project_root=cjs_project, module_root=cjs_project @@ -95,7 +95,7 @@ def test_extract_context_includes_direct_helpers(self, js_support, cjs_project): calculator_file = cjs_project / "calculator.js" functions = js_support.discover_functions(calculator_file) - permutation_func = next(f for f in functions if f.name == "permutation") + permutation_func = next(f for f in functions if f.function_name == "permutation") context = js_support.extract_code_context( function=permutation_func, project_root=cjs_project, module_root=cjs_project @@ -136,7 +136,7 @@ def test_extract_compound_interest_code(self, js_support, cjs_project): calculator_file = cjs_project / "calculator.js" functions = js_support.discover_functions(calculator_file) - compound_func = next(f for f in functions if f.name == "calculateCompoundInterest") + compound_func = next(f for f in functions if f.function_name == "calculateCompoundInterest") context = js_support.extract_code_context( function=compound_func, project_root=cjs_project, module_root=cjs_project @@ -182,7 +182,7 @@ def test_extract_compound_interest_helpers(self, js_support, cjs_project): calculator_file = cjs_project / "calculator.js" functions = js_support.discover_functions(calculator_file) - compound_func = next(f for f in functions if f.name == "calculateCompoundInterest") + compound_func = next(f for f in functions if f.function_name == "calculateCompoundInterest") context = js_support.extract_code_context( function=compound_func, project_root=cjs_project, module_root=cjs_project @@ -266,7 +266,7 @@ def test_extract_context_includes_imports(self, js_support, cjs_project): calculator_file = cjs_project / "calculator.js" functions = js_support.discover_functions(calculator_file) - compound_func = next(f for f in functions if f.name == "calculateCompoundInterest") + compound_func = next(f for f in functions if f.function_name == "calculateCompoundInterest") context = js_support.extract_code_context( function=compound_func, project_root=cjs_project, module_root=cjs_project @@ -287,7 +287,7 @@ def test_extract_static_method(self, js_support, cjs_project): calculator_file = cjs_project / "calculator.js" functions = js_support.discover_functions(calculator_file) - quick_add_func = next(f for f in functions if f.name == "quickAdd") + quick_add_func = next(f for f in functions if f.function_name == "quickAdd") context = js_support.extract_code_context( function=quick_add_func, project_root=cjs_project, module_root=cjs_project @@ -352,7 +352,7 @@ def test_discover_esm_methods(self, js_support, esm_project): calculator_file = esm_project / "calculator.js" functions = js_support.discover_functions(calculator_file) - method_names = {f.name for f in functions} + method_names = {f.function_name for f in functions} # Should find same methods as CJS version expected_methods = {"calculateCompoundInterest", "permutation", "quickAdd"} @@ -363,7 +363,7 @@ def test_esm_permutation_extraction(self, js_support, esm_project): calculator_file = esm_project / "calculator.js" functions = js_support.discover_functions(calculator_file) - permutation_func = next(f for f in functions if f.name == "permutation") + permutation_func = next(f for f in functions if f.function_name == "permutation") context = js_support.extract_code_context( function=permutation_func, project_root=esm_project, module_root=esm_project @@ -413,7 +413,7 @@ def test_esm_compound_interest_extraction(self, js_support, esm_project): calculator_file = esm_project / "calculator.js" functions = js_support.discover_functions(calculator_file) - compound_func = next(f for f in functions if f.name == "calculateCompoundInterest") + compound_func = next(f for f in functions if f.function_name == "calculateCompoundInterest") context = js_support.extract_code_context( function=compound_func, project_root=esm_project, module_root=esm_project @@ -539,7 +539,7 @@ def test_discover_ts_methods(self, ts_support, ts_project): calculator_file = ts_project / "calculator.ts" functions = ts_support.discover_functions(calculator_file) - method_names = {f.name for f in functions} + method_names = {f.function_name for f in functions} # TypeScript has additional getHistory method expected_methods = {"calculateCompoundInterest", "permutation", "getHistory", "quickAdd"} @@ -550,7 +550,7 @@ def test_ts_permutation_extraction(self, ts_support, ts_project): calculator_file = ts_project / "calculator.ts" functions = ts_support.discover_functions(calculator_file) - permutation_func = next(f for f in functions if f.name == "permutation") + permutation_func = next(f for f in functions if f.function_name == "permutation") context = ts_support.extract_code_context( function=permutation_func, project_root=ts_project, module_root=ts_project @@ -603,7 +603,7 @@ def test_ts_compound_interest_extraction(self, ts_support, ts_project): calculator_file = ts_project / "calculator.ts" functions = ts_support.discover_functions(calculator_file) - compound_func = next(f for f in functions if f.name == "calculateCompoundInterest") + compound_func = next(f for f in functions if f.function_name == "calculateCompoundInterest") context = ts_support.extract_code_context( function=compound_func, project_root=ts_project, module_root=ts_project @@ -712,7 +712,7 @@ def test_standalone_function(self, js_support, tmp_path): test_file.write_text(source) functions = js_support.discover_functions(test_file) - func = next(f for f in functions if f.name == "standalone") + func = next(f for f in functions if f.function_name == "standalone") context = js_support.extract_code_context(function=func, project_root=tmp_path, module_root=tmp_path) @@ -745,7 +745,7 @@ def test_external_package_excluded(self, js_support, tmp_path): test_file.write_text(source) functions = js_support.discover_functions(test_file) - func = next(f for f in functions if f.name == "processArray") + func = next(f for f in functions if f.function_name == "processArray") context = js_support.extract_code_context(function=func, project_root=tmp_path, module_root=tmp_path) @@ -780,7 +780,7 @@ def test_recursive_function(self, js_support, tmp_path): test_file.write_text(source) functions = js_support.discover_functions(test_file) - func = next(f for f in functions if f.name == "fibonacci") + func = next(f for f in functions if f.function_name == "fibonacci") context = js_support.extract_code_context(function=func, project_root=tmp_path, module_root=tmp_path) @@ -813,7 +813,7 @@ def test_arrow_function_helper(self, js_support, tmp_path): test_file.write_text(source) functions = js_support.discover_functions(test_file) - func = next(f for f in functions if f.name == "processValue") + func = next(f for f in functions if f.function_name == "processValue") context = js_support.extract_code_context(function=func, project_root=tmp_path, module_root=tmp_path) @@ -871,7 +871,7 @@ class Counter { test_file.write_text(source) functions = js_support.discover_functions(test_file) - increment_func = next(f for f in functions if f.name == "increment") + increment_func = next(f for f in functions if f.function_name == "increment") context = js_support.extract_code_context(function=increment_func, project_root=tmp_path, module_root=tmp_path) @@ -910,7 +910,7 @@ class MathUtils { test_file.write_text(source) functions = js_support.discover_functions(test_file) - add_func = next(f for f in functions if f.name == "add") + add_func = next(f for f in functions if f.function_name == "add") context = js_support.extract_code_context(function=add_func, project_root=tmp_path, module_root=tmp_path) @@ -948,7 +948,7 @@ class User { test_file.write_text(source) functions = ts_support.discover_functions(test_file) - get_name_func = next(f for f in functions if f.name == "getName") + get_name_func = next(f for f in functions if f.function_name == "getName") context = ts_support.extract_code_context(function=get_name_func, project_root=tmp_path, module_root=tmp_path) @@ -989,7 +989,7 @@ class Config { test_file.write_text(source) functions = ts_support.discover_functions(test_file) - get_url_func = next(f for f in functions if f.name == "getUrl") + get_url_func = next(f for f in functions if f.function_name == "getUrl") context = ts_support.extract_code_context(function=get_url_func, project_root=tmp_path, module_root=tmp_path) @@ -1030,7 +1030,7 @@ class Logger { test_file.write_text(source) functions = js_support.discover_functions(test_file) - get_prefix_func = next(f for f in functions if f.name == "getPrefix") + get_prefix_func = next(f for f in functions if f.function_name == "getPrefix") context = js_support.extract_code_context(function=get_prefix_func, project_root=tmp_path, module_root=tmp_path) @@ -1072,7 +1072,7 @@ class Factory { test_file.write_text(source) functions = js_support.discover_functions(test_file) - create_func = next(f for f in functions if f.name == "create") + create_func = next(f for f in functions if f.function_name == "create") context = js_support.extract_code_context(function=create_func, project_root=tmp_path, module_root=tmp_path) @@ -1114,19 +1114,20 @@ def test_function_optimizer_workflow(self, cjs_project): calculator_file = cjs_project / "calculator.js" functions = js_support.discover_functions(calculator_file) - target = next(f for f in functions if f.name == "permutation") + target = next(f for f in functions if f.function_name == "permutation") parents = [FunctionParent(name=p.name, type=p.type) for p in target.parents] func = FunctionToOptimize( - function_name=target.name, + function_name=target.function_name, file_path=target.file_path, parents=parents, - starting_line=target.start_line, - ending_line=target.end_line, - starting_col=target.start_col, - ending_col=target.end_col, + starting_line=target.starting_line, + ending_line=target.ending_line, + starting_col=target.starting_col, + ending_col=target.ending_col, is_async=target.is_async, + is_method=target.is_method, language=target.language, ) @@ -1223,7 +1224,7 @@ def test_extract_same_file_interface_from_parameter(self, ts_support, tmp_path): test_file.write_text(source) functions = ts_support.discover_functions(test_file) - distance_func = next(f for f in functions if f.name == "distance") + distance_func = next(f for f in functions if f.function_name == "distance") context = ts_support.extract_code_context(function=distance_func, project_root=tmp_path, module_root=tmp_path) @@ -1267,7 +1268,7 @@ def test_extract_same_file_enum_from_parameter(self, ts_support, tmp_path): test_file.write_text(source) functions = ts_support.discover_functions(test_file) - process_func = next(f for f in functions if f.name == "processStatus") + process_func = next(f for f in functions if f.function_name == "processStatus") context = ts_support.extract_code_context(function=process_func, project_root=tmp_path, module_root=tmp_path) @@ -1304,7 +1305,7 @@ def test_extract_same_file_type_alias_from_return_type(self, ts_support, tmp_pat test_file.write_text(source) functions = ts_support.discover_functions(test_file) - compute_func = next(f for f in functions if f.name == "compute") + compute_func = next(f for f in functions if f.function_name == "compute") context = ts_support.extract_code_context(function=compute_func, project_root=tmp_path, module_root=tmp_path) @@ -1348,7 +1349,7 @@ class Service { test_file.write_text(source) functions = ts_support.discover_functions(test_file) - get_timeout_func = next(f for f in functions if f.name == "getTimeout") + get_timeout_func = next(f for f in functions if f.function_name == "getTimeout") context = ts_support.extract_code_context( function=get_timeout_func, project_root=tmp_path, module_root=tmp_path @@ -1381,7 +1382,7 @@ def test_primitive_types_not_included(self, ts_support, tmp_path): test_file.write_text(source) functions = ts_support.discover_functions(test_file) - add_func = next(f for f in functions if f.name == "add") + add_func = next(f for f in functions if f.function_name == "add") context = ts_support.extract_code_context(function=add_func, project_root=tmp_path, module_root=tmp_path) @@ -1414,7 +1415,7 @@ def test_extract_multiple_types(self, ts_support, tmp_path): test_file.write_text(source) functions = ts_support.discover_functions(test_file) - create_rect_func = next(f for f in functions if f.name == "createRect") + create_rect_func = next(f for f in functions if f.function_name == "createRect") context = ts_support.extract_code_context( function=create_rect_func, project_root=tmp_path, module_root=tmp_path @@ -1462,7 +1463,7 @@ def test_extract_imported_type_definition(self, ts_support, ts_types_project): """) functions = ts_support.discover_functions(geometry_file) - calc_distance_func = next(f for f in functions if f.name == "calculateDistance") + calc_distance_func = next(f for f in functions if f.function_name == "calculateDistance") context = ts_support.extract_code_context( function=calc_distance_func, project_root=ts_types_project, module_root=ts_types_project @@ -1515,7 +1516,7 @@ def test_type_with_jsdoc_included(self, ts_support, tmp_path): test_file.write_text(source) functions = ts_support.discover_functions(test_file) - greet_func = next(f for f in functions if f.name == "greetUser") + greet_func = next(f for f in functions if f.function_name == "greetUser") context = ts_support.extract_code_context(function=greet_func, project_root=tmp_path, module_root=tmp_path) diff --git a/tests/test_languages/test_js_code_replacer.py b/tests/test_languages/test_js_code_replacer.py index 3d703aa34..bb3db03de 100644 --- a/tests/test_languages/test_js_code_replacer.py +++ b/tests/test_languages/test_js_code_replacer.py @@ -711,7 +711,7 @@ def test_replace_preserves_surrounding_code(self, js_support, temp_project): file_path.write_text(original_source, encoding="utf-8") functions = js_support.discover_functions(file_path) - target_func = next(f for f in functions if f.name == "targetFunction") + target_func = next(f for f in functions if f.function_name == "targetFunction") optimized_code = """\ function targetFunction(x) { @@ -763,7 +763,7 @@ class Calculator { file_path.write_text(original_source, encoding="utf-8") functions = js_support.discover_functions(file_path) - add_method = next(f for f in functions if f.name == "add") + add_method = next(f for f in functions if f.function_name == "add") # Optimized version provided in class context optimized_code = """\ @@ -826,7 +826,7 @@ class DataProcessor { file_path.write_text(original_source, encoding="utf-8") functions = js_support.discover_functions(file_path) - process_method = next(f for f in functions if f.name == "process") + process_method = next(f for f in functions if f.function_name == "process") optimized_code = """\ class DataProcessor { @@ -948,7 +948,7 @@ class Cache { file_path.write_text(original_source, encoding="utf-8") functions = js_support.discover_functions(file_path) - get_method = next(f for f in functions if f.name == "get") + get_method = next(f for f in functions if f.function_name == "get") optimized_code = """\ class Cache { @@ -1050,7 +1050,7 @@ class ApiClient { file_path.write_text(original_source, encoding="utf-8") functions = js_support.discover_functions(file_path) - get_method = next(f for f in functions if f.name == "get") + get_method = next(f for f in functions if f.function_name == "get") optimized_code = """\ class ApiClient { @@ -1181,7 +1181,7 @@ class Container { file_path.write_text(original_source, encoding="utf-8") functions = ts_support.discover_functions(file_path) - get_all_method = next(f for f in functions if f.name == "getAll") + get_all_method = next(f for f in functions if f.function_name == "getAll") optimized_code = """\ class Container { @@ -1234,7 +1234,7 @@ def test_replace_typescript_interface_typed_function(self, ts_support, temp_proj file_path.write_text(original_source, encoding="utf-8") functions = ts_support.discover_functions(file_path) - func = next(f for f in functions if f.name == "createUser") + func = next(f for f in functions if f.function_name == "createUser") optimized_code = """\ function createUser(name: string, email: string): User { @@ -1289,7 +1289,7 @@ def test_replace_function_with_nested_functions(self, js_support, temp_project): file_path.write_text(original_source, encoding="utf-8") functions = js_support.discover_functions(file_path) - process_func = next(f for f in functions if f.name == "processItems") + process_func = next(f for f in functions if f.function_name == "processItems") optimized_code = """\ function processItems(items) { @@ -1336,7 +1336,7 @@ class MathUtils { # First replacement: sum method functions = js_support.discover_functions(file_path) - sum_method = next(f for f in functions if f.name == "sum") + sum_method = next(f for f in functions if f.function_name == "sum") optimized_sum = """\ class MathUtils { @@ -1554,7 +1554,7 @@ def test_replace_exported_function_commonjs(self, js_support, temp_project): file_path.write_text(original_source, encoding="utf-8") functions = js_support.discover_functions(file_path) - main_func = next(f for f in functions if f.name == "main") + main_func = next(f for f in functions if f.function_name == "main") optimized_code = """\ function main(data) { @@ -1597,7 +1597,7 @@ def test_replace_exported_function_esm(self, js_support, temp_project): file_path.write_text(original_source, encoding="utf-8") functions = js_support.discover_functions(file_path) - main_func = next(f for f in functions if f.name == "main") + main_func = next(f for f in functions if f.function_name == "main") optimized_code = """\ export function main(data) { @@ -1756,7 +1756,7 @@ def test_code_replacer_for_class_method(ts_support, temp_project): # find function target_func_info = None for func in functions: - if func.name == target_func and func.parents[0].name == parent_class: + if func.function_name == target_func and func.parents[0].name == parent_class: target_func_info = func break assert target_func_info is not None diff --git a/tests/test_languages/test_multi_file_code_replacer.py b/tests/test_languages/test_multi_file_code_replacer.py index cd21de104..65f3930e5 100644 --- a/tests/test_languages/test_multi_file_code_replacer.py +++ b/tests/test_languages/test_multi_file_code_replacer.py @@ -113,19 +113,20 @@ def test_js_replcement() -> None: functions = js_support.discover_functions(main_file) target = None for func in functions: - if func.name == "calculateStats": + if func.function_name == "calculateStats": target = func break assert target is not None func = FunctionToOptimize( - function_name=target.name, + function_name=target.function_name, file_path=target.file_path, parents=target.parents, - starting_line=target.start_line, - ending_line=target.end_line, - starting_col=target.start_col, - ending_col=target.end_col, + starting_line=target.starting_line, + ending_line=target.ending_line, + starting_col=target.starting_col, + ending_col=target.ending_col, is_async=target.is_async, + is_method=target.is_method, language=target.language, ) test_config = TestConfig( diff --git a/tests/test_languages/test_treesitter_utils.py b/tests/test_languages/test_treesitter_utils.py index a557a84dc..e5e776a11 100644 --- a/tests/test_languages/test_treesitter_utils.py +++ b/tests/test_languages/test_treesitter_utils.py @@ -136,7 +136,7 @@ def test_find_function_declaration(self, js_analyzer): functions = js_analyzer.find_functions(code) assert len(functions) == 1 - assert functions[0].function_name == "add" + assert functions[0].name == "add" assert functions[0].is_arrow is False assert functions[0].is_async is False assert functions[0].is_method is False @@ -151,7 +151,7 @@ def test_find_arrow_function(self, js_analyzer): functions = js_analyzer.find_functions(code) assert len(functions) == 1 - assert functions[0].function_name == "add" + assert functions[0].name == "add" assert functions[0].is_arrow is True def test_find_arrow_function_concise(self, js_analyzer): @@ -160,7 +160,7 @@ def test_find_arrow_function_concise(self, js_analyzer): functions = js_analyzer.find_functions(code) assert len(functions) == 1 - assert functions[0].function_name == "double" + assert functions[0].name == "double" assert functions[0].is_arrow is True def test_find_async_function(self, js_analyzer): @@ -173,7 +173,7 @@ def test_find_async_function(self, js_analyzer): functions = js_analyzer.find_functions(code) assert len(functions) == 1 - assert functions[0].function_name == "fetchData" + assert functions[0].name == "fetchData" assert functions[0].is_async is True def test_find_class_methods(self, js_analyzer): @@ -188,7 +188,7 @@ class Calculator { functions = js_analyzer.find_functions(code, include_methods=True) assert len(functions) == 1 - assert functions[0].function_name == "add" + assert functions[0].name == "add" assert functions[0].is_method is True assert functions[0].class_name == "Calculator" @@ -208,7 +208,7 @@ class Calculator { functions = js_analyzer.find_functions(code, include_methods=False) assert len(functions) == 1 - assert functions[0].function_name == "standalone" + assert functions[0].name == "standalone" def test_exclude_arrow_functions(self, js_analyzer): """Test excluding arrow functions.""" @@ -222,7 +222,7 @@ def test_exclude_arrow_functions(self, js_analyzer): functions = js_analyzer.find_functions(code, include_arrow_functions=False) assert len(functions) == 1 - assert functions[0].function_name == "regular" + assert functions[0].name == "regular" def test_find_generator_function(self, js_analyzer): """Test finding generator functions.""" @@ -235,7 +235,7 @@ def test_find_generator_function(self, js_analyzer): functions = js_analyzer.find_functions(code) assert len(functions) == 1 - assert functions[0].function_name == "numberGenerator" + assert functions[0].name == "numberGenerator" assert functions[0].is_generator is True def test_function_line_numbers(self, js_analyzer): @@ -291,7 +291,7 @@ def test_require_name_filters_anonymous(self, js_analyzer): functions = js_analyzer.find_functions(code, require_name=True) assert len(functions) == 1 - assert functions[0].function_name == "named" + assert functions[0].name == "named" def test_function_expression_in_variable(self, js_analyzer): """Test function expression assigned to variable.""" @@ -303,7 +303,7 @@ def test_function_expression_in_variable(self, js_analyzer): functions = js_analyzer.find_functions(code) assert len(functions) == 1 - assert functions[0].function_name == "add" + assert functions[0].name == "add" class TestFindImports: @@ -515,7 +515,7 @@ def test_find_typed_function(self, ts_analyzer): functions = ts_analyzer.find_functions(code) assert len(functions) == 1 - assert functions[0].function_name == "add" + assert functions[0].name == "add" def test_find_interface_method(self, ts_analyzer): """Test that interface methods are not found (they're declarations).""" @@ -544,4 +544,4 @@ def test_find_generic_function(self, ts_analyzer): functions = ts_analyzer.find_functions(code) assert len(functions) == 1 - assert functions[0].function_name == "identity" + assert functions[0].name == "identity" From 09868a4b8c6ce8f7d5f77e8d62eeff56e56fbf97 Mon Sep 17 00:00:00 2001 From: misrasaurabh1 Date: Mon, 2 Feb 2026 16:54:48 -0800 Subject: [PATCH 4/5] fix: resolve CI failures for mypy, Python 3.9, and Windows - Fix mypy errors by correcting import paths for FunctionParent and Language - Add type annotations to _handle_config_loading in main.py - Fix has_existing_config tuple unpacking in cli.py - Replace env_utils.os.environ with os.environ in main.py - Add 'from __future__ import annotations' to test files for Python 3.9 compatibility - Use .as_posix() for paths in markdown code fences for Windows compatibility - Use .resolve() on temp paths in tests to handle Windows short path names Co-Authored-By: Claude Opus 4.5 --- codeflash/cli_cmds/cli.py | 5 +++-- codeflash/code_utils/code_extractor.py | 2 +- codeflash/code_utils/static_analysis.py | 2 +- codeflash/context/code_context_extractor.py | 3 +-- codeflash/main.py | 13 ++++++++----- tests/test_languages/test_find_references.py | 2 ++ .../test_javascript_instrumentation.py | 2 ++ tests/test_languages/test_javascript_test_runner.py | 4 ++-- 8 files changed, 20 insertions(+), 13 deletions(-) diff --git a/codeflash/cli_cmds/cli.py b/codeflash/cli_cmds/cli.py index 11a9bf02e..4c7654b6d 100644 --- a/codeflash/cli_cmds/cli.py +++ b/codeflash/cli_cmds/cli.py @@ -358,7 +358,7 @@ def _handle_show_config() -> None: detected = detect_project(project_root) # Check if config exists or is auto-detected - config_exists = has_existing_config(project_root) + config_exists, _ = has_existing_config(project_root) status = "Saved config" if config_exists else "Auto-detected (not saved)" console.print() @@ -400,7 +400,8 @@ def _handle_reset_config(confirm: bool = True) -> None: project_root = Path.cwd() - if not has_existing_config(project_root): + config_exists, _ = has_existing_config(project_root) + if not config_exists: console.print("[yellow]No Codeflash configuration found to remove.[/yellow]") return diff --git a/codeflash/code_utils/code_extractor.py b/codeflash/code_utils/code_extractor.py index dc198b0f8..4e19f53be 100644 --- a/codeflash/code_utils/code_extractor.py +++ b/codeflash/code_utils/code_extractor.py @@ -1709,7 +1709,7 @@ def _format_references_as_markdown(references: list, file_path: Path, project_ro context_len += len(context_code) if caller_contexts: - fn_call_context += f"```{lang_hint}:{path_relative}\n" + fn_call_context += f"```{lang_hint}:{path_relative.as_posix()}\n" fn_call_context += "\n".join(caller_contexts) fn_call_context += "\n```\n" diff --git a/codeflash/code_utils/static_analysis.py b/codeflash/code_utils/static_analysis.py index 0151e29e7..a0d04bfb1 100644 --- a/codeflash/code_utils/static_analysis.py +++ b/codeflash/code_utils/static_analysis.py @@ -8,7 +8,7 @@ from pydantic import BaseModel, ConfigDict, field_validator if TYPE_CHECKING: - from codeflash.models.models import FunctionParent + from codeflash.models.function_types import FunctionParent ObjectDefT = TypeVar("ObjectDefT", ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef) diff --git a/codeflash/context/code_context_extractor.py b/codeflash/context/code_context_extractor.py index 466d2aa46..18db28856 100644 --- a/codeflash/context/code_context_extractor.py +++ b/codeflash/context/code_context_extractor.py @@ -23,8 +23,7 @@ from codeflash.discovery.functions_to_optimize import FunctionToOptimize # noqa: TC001 # Language support imports for multi-language code context extraction -from codeflash.languages import is_python -from codeflash.languages.base import Language +from codeflash.languages import Language, is_python from codeflash.models.models import ( CodeContextType, CodeOptimizationContext, diff --git a/codeflash/main.py b/codeflash/main.py index 2454b51ce..ca4a5e825 100644 --- a/codeflash/main.py +++ b/codeflash/main.py @@ -4,7 +4,9 @@ solved problem, please reach out to us at careers@codeflash.ai. We're hiring! """ +import os import sys +from argparse import Namespace from pathlib import Path from codeflash.cli_cmds.cli import parse_args, process_pyproject_config @@ -41,9 +43,10 @@ def main() -> None: ask_run_end_to_end_test(args) else: # Check for first-run experience (no config exists) - args = _handle_config_loading(args) - if args is None: + loaded_args = _handle_config_loading(args) + if loaded_args is None: sys.exit(0) + args = loaded_args if not env_utils.check_formatter_installed(args.formatter_cmds): return @@ -56,7 +59,7 @@ def main() -> None: optimizer.run_with_args(args) -def _handle_config_loading(args): +def _handle_config_loading(args: Namespace) -> Namespace | None: """Handle config loading with first-run experience support. If no config exists and not in CI, triggers the first-run experience. @@ -74,13 +77,13 @@ def _handle_config_loading(args): # Check if we're in CI environment is_ci = any( var in ("true", "1", "True") - for var in [env_utils.os.environ.get("CI", ""), env_utils.os.environ.get("GITHUB_ACTIONS", "")] + for var in [os.environ.get("CI", ""), os.environ.get("GITHUB_ACTIONS", "")] ) # Check if first run (no config exists) if is_first_run() and not is_ci: # Skip API key check if already set - skip_api_key = bool(env_utils.os.environ.get("CODEFLASH_API_KEY")) + skip_api_key = bool(os.environ.get("CODEFLASH_API_KEY")) # Handle first-run experience result = handle_first_run(args=args, skip_confirm=getattr(args, "yes", False), skip_api_key=skip_api_key) diff --git a/tests/test_languages/test_find_references.py b/tests/test_languages/test_find_references.py index 970ff05d9..537e3ef0b 100644 --- a/tests/test_languages/test_find_references.py +++ b/tests/test_languages/test_find_references.py @@ -8,6 +8,8 @@ 2. The formatted markdown output from _format_references_as_markdown """ +from __future__ import annotations + import pytest from pathlib import Path diff --git a/tests/test_languages/test_javascript_instrumentation.py b/tests/test_languages/test_javascript_instrumentation.py index 9896d1d69..ba25a3af5 100644 --- a/tests/test_languages/test_javascript_instrumentation.py +++ b/tests/test_languages/test_javascript_instrumentation.py @@ -3,6 +3,8 @@ This module tests the line profiling and tracing instrumentation for JavaScript code. """ +from __future__ import annotations + import tempfile from pathlib import Path diff --git a/tests/test_languages/test_javascript_test_runner.py b/tests/test_languages/test_javascript_test_runner.py index 8b3c6205f..87e712038 100644 --- a/tests/test_languages/test_javascript_test_runner.py +++ b/tests/test_languages/test_javascript_test_runner.py @@ -18,7 +18,7 @@ def test_behavioral_tests_adds_roots_for_test_directories(self): # Create mock test files in a test directory with tempfile.TemporaryDirectory() as tmpdir: - tmpdir_path = Path(tmpdir) + tmpdir_path = Path(tmpdir).resolve() test_dir = tmpdir_path / "test" test_dir.mkdir() @@ -90,7 +90,7 @@ def test_benchmarking_tests_adds_roots_for_test_directories(self): from codeflash.models.test_type import TestType with tempfile.TemporaryDirectory() as tmpdir: - tmpdir_path = Path(tmpdir) + tmpdir_path = Path(tmpdir).resolve() test_dir = tmpdir_path / "test" test_dir.mkdir() From 44e1cf1a337c6fbf56ca3ac233baa57f498c9e53 Mon Sep 17 00:00:00 2001 From: misrasaurabh1 Date: Mon, 2 Feb 2026 17:00:35 -0800 Subject: [PATCH 5/5] lint fix --- codeflash/main.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/codeflash/main.py b/codeflash/main.py index ca4a5e825..690c1ae98 100644 --- a/codeflash/main.py +++ b/codeflash/main.py @@ -4,10 +4,12 @@ solved problem, please reach out to us at careers@codeflash.ai. We're hiring! """ +from __future__ import annotations + import os import sys -from argparse import Namespace from pathlib import Path +from typing import TYPE_CHECKING from codeflash.cli_cmds.cli import parse_args, process_pyproject_config from codeflash.cli_cmds.cmd_init import CODEFLASH_LOGO, ask_run_end_to_end_test @@ -19,6 +21,9 @@ from codeflash.telemetry import posthog_cf from codeflash.telemetry.sentry import init_sentry +if TYPE_CHECKING: + from argparse import Namespace + def main() -> None: """Entry point for the codeflash command-line interface.""" @@ -76,8 +81,7 @@ def _handle_config_loading(args: Namespace) -> Namespace | None: # Check if we're in CI environment is_ci = any( - var in ("true", "1", "True") - for var in [os.environ.get("CI", ""), os.environ.get("GITHUB_ACTIONS", "")] + var in ("true", "1", "True") for var in [os.environ.get("CI", ""), os.environ.get("GITHUB_ACTIONS", "")] ) # Check if first run (no config exists)