def get_opt_review_metrics(
    source_code: str, file_path: Path, qualified_name: str, project_root: Path, tests_root: Path, language: Language
) -> str:
    """Get function reference metrics for optimization review.

    Looks up the language support registered for ``language`` and asks it for the references to the
    function, then renders the calling code as markdown.

    Args:
        source_code: Source code of the file containing the function. Not read by this implementation;
            kept so existing callers do not have to change.
        file_path: Path to the file that defines the function.
        qualified_name: Dotted name of the function. Everything before the last dot is passed on as a
            single parent (class) name; a name without a dot has no parent.
        project_root: Root of the project.
        tests_root: Root of the tests directory, forwarded to the language support.
        language: The programming language.

    Returns:
        Markdown-formatted string with code blocks showing calling functions, or "" when no language
        support is registered, no references are found, or the lookup raises.

    """
    from codeflash.languages.base import FunctionInfo, ParentInfo
    from codeflash.languages.registry import get_language_support

    start_time = time.perf_counter()
    calling_fns_details = ""
    try:
        lang_support = get_language_support(language)
        if lang_support is not None:
            class_name, _, function_name = qualified_name.rpartition(".")
            parents = (ParentInfo(name=class_name, type="ClassDef"),) if class_name else ()

            # Line information is not available here, so placeholder line numbers are used.
            func_info = FunctionInfo(
                name=function_name,
                file_path=file_path,
                start_line=1,
                end_line=1,
                parents=parents,
                language=language,
            )

            references = lang_support.find_references(func_info, project_root, tests_root, max_files=500)
            if references:
                calling_fns_details = _format_references_as_markdown(references, file_path, project_root, language)
    except Exception as e:
        # Reference lookup is best-effort: a failure must not abort the review.
        logger.debug("Error getting function references: %s", e)
        calling_fns_details = ""

    # Reached on every path (no early returns), so the timing is always logged.
    end_time = time.perf_counter()
    logger.debug("Got function references in %.2f seconds", end_time - start_time)
    return calling_fns_details
def _format_references_as_markdown(
    references: list, file_path: Path, project_root: Path, language: Language
) -> str:
    """Format references as markdown code blocks with calling function code.

    One fenced block is emitted per referencing file. Within a file, only the first reference per
    caller name is rendered (all module-level references share one slot).

    Args:
        references: List of ReferenceInfo objects.
        file_path: Path to the source file; its "import"/"reexport" references are skipped.
        project_root: Root of the project; files outside it are skipped.
        language: The programming language.

    Returns:
        Markdown-formatted string ("" when nothing could be rendered).

    """
    refs_by_file: dict[Path, list] = {}
    for ref in references:
        if ref.file_path == file_path and ref.reference_type in ("import", "reexport"):
            continue
        refs_by_file.setdefault(ref.file_path, []).append(ref)

    sections: list[str] = []
    context_len = 0

    for ref_file, file_refs in refs_by_file.items():
        # The budget is checked between files, so a single file may overshoot it.
        if context_len > MAX_CONTEXT_LEN_REVIEW:
            break

        try:
            path_relative = ref_file.relative_to(project_root)
        except ValueError:
            continue

        if language == Language.PYTHON:
            lang_hint = "python"
        elif ref_file.suffix in (".ts", ".tsx"):
            lang_hint = "typescript"
        else:
            lang_hint = "javascript"

        try:
            file_content = ref_file.read_text(encoding="utf-8")
        except (OSError, UnicodeError):
            continue
        lines = file_content.splitlines()

        callers_seen: set[str] = set()
        caller_contexts: list[str] = []

        for ref in file_refs:
            caller = ref.caller_function or ""
            if caller in callers_seen:
                continue
            callers_seen.add(caller)

            snippet = None
            if ref.caller_function:
                snippet = _extract_calling_function(file_content, ref.caller_function, ref.line, language)
            if not snippet:
                # Module-level reference, or the caller's source could not be located:
                # show a few lines around the reference instead of dropping it.
                first = max(0, ref.line - 3)
                last = min(len(lines), ref.line + 2)
                snippet = "\n".join(lines[first:last])

            caller_contexts.append(snippet)
            context_len += len(snippet)

        if caller_contexts:
            sections.append(f"```{lang_hint}:{path_relative}\n" + "\n".join(caller_contexts) + "\n```\n")

    return "".join(sections)


def _extract_calling_function(source_code: str, function_name: str, ref_line: int, language: Language) -> str | None:
    """Extract the source code of a calling function.

    Args:
        source_code: Full source code of the file.
        function_name: Name of the function to extract.
        ref_line: Line number where the reference is.
        language: The programming language; anything other than Python uses the JS/TS extractor.

    Returns:
        Source code of the function, or None if not found.

    """
    if language == Language.PYTHON:
        return _extract_calling_function_python(source_code, function_name, ref_line)
    return _extract_calling_function_js(source_code, function_name, ref_line)


def _extract_calling_function_python(source_code: str, function_name: str, ref_line: int) -> str | None:
    """Extract the source code of a calling function in Python.

    Args:
        source_code: Full source code of the file.
        function_name: Name of the (sync or async) function to extract.
        ref_line: 1-indexed line of the reference; the function must span it.

    Returns:
        The lines from the ``def`` line through the function's last line (decorators are not
        included), or None when the source does not parse or no such function spans ``ref_line``.

    """
    import ast

    try:
        tree = ast.parse(source_code)
    except (SyntaxError, ValueError, RecursionError):
        return None

    lines = source_code.splitlines()
    for node in ast.walk(tree):
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) and node.name == function_name:
            start_line = node.lineno
            end_line = node.end_lineno or start_line
            # The line check tells same-named functions apart.
            if start_line <= ref_line <= end_line:
                return "\n".join(lines[start_line - 1 : end_line])
    return None


def _extract_calling_function_js(source_code: str, function_name: str, ref_line: int) -> str | None:
    """Extract the source code of a calling function in JavaScript/TypeScript.

    The TypeScript, TSX and JavaScript grammars are tried in that order. The next grammar is tried
    both when one raises and when it yields no matching function.

    Args:
        source_code: Full source code of the file.
        function_name: Name of the function to extract.
        ref_line: Line number where the reference is (helps identify the right function).

    Returns:
        Source code of the function, or None if not found.

    """
    try:
        from codeflash.languages.treesitter_utils import TreeSitterAnalyzer, TreeSitterLanguage

        grammars = (TreeSitterLanguage.TYPESCRIPT, TreeSitterLanguage.TSX, TreeSitterLanguage.JAVASCRIPT)
    except Exception:
        return None

    for grammar in grammars:
        try:
            analyzer = TreeSitterAnalyzer(grammar)
            for func in analyzer.find_functions(source_code, include_methods=True):
                if func.name == function_name and func.start_line <= ref_line <= func.end_line:
                    return func.source_text
        except Exception:
            # Best-effort: a grammar that cannot handle this source is simply skipped.
            continue

    return None
@dataclass
class ReferenceInfo:
    """Describe one reference (call site) to a function.

    Each instance records where the reference sits in a file, the line of code it appears on,
    what kind of reference it is, and which function (if any) contains it.
    """

    # Path to the file containing the reference.
    file_path: Path
    # Line number of the reference (1-indexed).
    line: int
    # Column number of the reference (0-indexed).
    column: int
    # End line number (1-indexed).
    end_line: int
    # End column number (0-indexed).
    end_column: int
    # The line of code containing the reference.
    context: str
    # Type of reference: "call", "callback", "memoized", "import" or "reexport".
    reference_type: str
    # Name used to import the function (may differ from the original name).
    import_name: str | None
    # Name of the function containing this reference; None for module-level references.
    caller_function: str | None = None
def find_references(
    self, function: FunctionInfo, project_root: Path, tests_root: Path | None = None, max_files: int = 500
) -> list[ReferenceInfo]:
    """Locate every reference (call site) to ``function`` in the codebase.

    Implementations are expected to report all places the function is used:
    - direct calls
    - callbacks (the function passed to another function)
    - memoized versions
    - re-exports

    Args:
        function: The function whose references are wanted.
        project_root: Root of the project to search.
        tests_root: Root of the tests directory (references in tests are excluded).
        max_files: Maximum number of files to search.

    Returns:
        List of ReferenceInfo objects, one per reference location.

    """
    ...
"""Find references for JavaScript/TypeScript functions.

This module provides functionality to find all references (call sites) of a function
across a JavaScript/TypeScript codebase. Similar to Jedi's find_references for Python,
this uses tree-sitter to parse and analyze code.

Key features:
- Finds all call sites of a function across multiple files
- Handles various import patterns (named, default, namespace, re-exports, aliases)
- Supports both ES modules and CommonJS
- Handles memoized functions, callbacks, and method calls
"""

from __future__ import annotations

import logging
import os
from dataclasses import dataclass, field
from fnmatch import fnmatchcase
from pathlib import Path
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from tree_sitter import Node

    from codeflash.languages.javascript.import_resolver import ImportResolver
    from codeflash.languages.treesitter_utils import ExportInfo, ImportInfo, TreeSitterAnalyzer

logger = logging.getLogger(__name__)


@dataclass
class Reference:
    """Represents a reference (call site) to a function."""

    file_path: Path  # File containing the reference
    line: int  # 1-indexed line number
    column: int  # 0-indexed column number
    end_line: int  # 1-indexed end line
    end_column: int  # 0-indexed end column
    context: str  # The line of code containing the reference
    reference_type: str  # Type: "call", "callback", "memoized", "import", "reexport"
    import_name: str | None  # Name used to import the function (may differ from original)
    caller_function: str | None = None  # Name of the function containing this reference


@dataclass
class ExportedFunction:
    """Represents how a function is exported from its source file."""

    function_name: str  # The local function name
    export_name: str | None  # The name it's exported as (may differ)
    is_default: bool  # Whether it's a default export
    file_path: Path  # The source file


@dataclass
class ReferenceSearchContext:
    """Context for tracking visited files during reference search."""

    visited_files: set[Path] = field(default_factory=set)
    max_files: int = 1000  # Limit to prevent runaway searches


class ReferenceFinder:
    """Finds all references to a function across a JavaScript/TypeScript codebase.

    This class provides functionality similar to Jedi's find_references for Python,
    but for JavaScript/TypeScript using tree-sitter.

    Example usage:
        ```python
        from codeflash.languages.javascript.find_references import ReferenceFinder

        finder = ReferenceFinder(project_root=Path("/my/project"))
        references = finder.find_references(
            function_name="myHelper",
            source_file=Path("/my/project/src/utils.ts")
        )
        for ref in references:
            print(f"{ref.file_path}:{ref.line} - {ref.context}")
        ```
    """

    # File extensions to search
    EXTENSIONS = (".ts", ".tsx", ".js", ".jsx", ".mjs", ".cjs")

    # Exclusions applied when the caller does not pass exclude_patterns
    DEFAULT_EXCLUDE_PATTERNS = ("node_modules", "dist", "build", ".git", "coverage", "__pycache__")

    # Fragments of a wrapper's name that mark the wrapped function as memoized
    _MEMOIZER_HINTS = ("memoize", "memo", "cache")

    def __init__(self, project_root: Path, exclude_patterns: list[str] | None = None) -> None:
        """Initialize the ReferenceFinder.

        Args:
            project_root: Root directory of the project to search.
            exclude_patterns: Glob patterns of directories/files to exclude. None selects
                DEFAULT_EXCLUDE_PATTERNS; an explicit empty list disables exclusion.

        """
        self.project_root = project_root
        self.exclude_patterns = list(self.DEFAULT_EXCLUDE_PATTERNS) if exclude_patterns is None else exclude_patterns
        self._file_cache: dict[Path, str] = {}
        self._resolver: ImportResolver | None = None

    def find_references(
        self,
        function_name: str,
        source_file: Path,
        include_definition: bool = False,
        max_files: int = 1000,
    ) -> list[Reference]:
        """Find all references to a function across the project.

        Args:
            function_name: Name of the function to find references for.
            source_file: Path to the file where the function is defined.
            include_definition: For an exported function, also search the defining file itself.
                A function that is not exported is only ever searched for in its defining file.
            max_files: Stops the search once this many files have been matched. The count covers
                the source file plus every file in which a matching import was found, not the
                number of files scanned.

        Returns:
            List of Reference objects, de-duplicated by (file, line, column).

        """
        from codeflash.languages.treesitter_utils import get_analyzer_for_file

        source_code = self._read_file(source_file)
        if source_code is None:
            logger.warning("Could not read source file: %s", source_file)
            return []

        analyzer = get_analyzer_for_file(source_file)
        exported = self._analyze_exports(function_name, source_file, source_code, analyzer)

        if exported is None:
            logger.debug("Function %s is not exported from %s", function_name, source_file)
            # Nothing else can import it, so only the defining file is searched.
            return self._dedupe(
                self._find_references_in_file(
                    source_file, source_code, function_name, None, analyzer, include_self=not include_definition
                )
            )

        context = ReferenceSearchContext(max_files=max_files)
        context.visited_files.add(source_file)

        # Walk the tree once; both passes below reuse the listing and the parsed imports.
        project_files = self._iter_project_files()
        parsed: dict[Path, tuple[str, TreeSitterAnalyzer, list[ImportInfo]] | None] = {}

        def load(path: Path) -> tuple[str, TreeSitterAnalyzer, list[ImportInfo]] | None:
            if path not in parsed:
                code = self._read_file(path)
                if code is None:
                    parsed[path] = None
                else:
                    file_analyzer = get_analyzer_for_file(path)
                    parsed[path] = (code, file_analyzer, file_analyzer.find_imports(code))
            return parsed[path]

        references: list[Reference] = []
        reexport_files: list[tuple[Path, str]] = []  # (file_path, export_name)

        # Pass 1: files importing directly from the source file, and files re-exporting from it.
        for file_path in project_files:
            if file_path in context.visited_files:
                continue
            if len(context.visited_files) >= context.max_files:
                logger.warning("Reached max file limit (%d), stopping search", max_files)
                break

            loaded = load(file_path)
            if loaded is None:
                continue
            file_code, file_analyzer, imports = loaded

            match = self._find_matching_import(imports, source_file, file_path, exported)
            if match:
                context.visited_files.add(file_path)
                import_name, _ = match
                references.extend(
                    self._find_references_in_file(
                        file_path, file_code, function_name, import_name, file_analyzer, include_self=True
                    )
                )

            # Re-exports are checked even without a direct import match.
            reexport_refs = self._find_reexports_direct(file_path, file_code, source_file, exported, file_analyzer)
            references.extend(reexport_refs)
            reexport_files.extend((file_path, ref.import_name) for ref in reexport_refs)

        # Pass 2: follow each re-export one level, to the files importing from the re-exporting file.
        for reexport_file, reexport_name in reexport_files:
            reexported = ExportedFunction(
                function_name=reexport_name,
                export_name=reexport_name,
                is_default=False,
                file_path=reexport_file,
            )

            for file_path in project_files:
                if file_path in context.visited_files or file_path == reexport_file:
                    continue
                if len(context.visited_files) >= context.max_files:
                    break

                loaded = load(file_path)
                if loaded is None:
                    continue
                file_code, file_analyzer, imports = loaded

                match = self._find_matching_import(imports, reexport_file, file_path, reexported)
                if match:
                    context.visited_files.add(file_path)
                    import_name, _ = match
                    references.extend(
                        self._find_references_in_file(
                            file_path, file_code, reexport_name, import_name, file_analyzer, include_self=True
                        )
                    )

        # Pass 3: internal references in the defining file, only on request.
        if include_definition:
            references.extend(
                self._find_references_in_file(
                    source_file, source_code, function_name, None, analyzer, include_self=True
                )
            )

        return self._dedupe(references)

    @staticmethod
    def _dedupe(references: list[Reference]) -> list[Reference]:
        """Drop references repeating an earlier (file, line, column); order is preserved."""
        seen: set[tuple[Path, int, int]] = set()
        unique_refs: list[Reference] = []
        for ref in references:
            key = (ref.file_path, ref.line, ref.column)
            if key not in seen:
                seen.add(key)
                unique_refs.append(ref)
        return unique_refs

    def _get_resolver(self) -> ImportResolver:
        """Return the ImportResolver for this project, creating it on first use."""
        resolver = getattr(self, "_resolver", None)
        if resolver is None:
            from codeflash.languages.javascript.import_resolver import ImportResolver

            resolver = ImportResolver(self.project_root)
            self._resolver = resolver
        return resolver

    def _analyze_exports(
        self, function_name: str, file_path: Path, source_code: str, analyzer: TreeSitterAnalyzer
    ) -> ExportedFunction | None:
        """Analyze how a function is exported from its file.

        Args:
            function_name: Name of the function to check.
            file_path: Path to the source file.
            source_code: Source code content.
            analyzer: TreeSitterAnalyzer instance.

        Returns:
            ExportedFunction if the function is exported, None otherwise.

        """
        is_exported, export_name = analyzer.is_function_exported(source_code, function_name)
        if not is_exported:
            return None

        return ExportedFunction(
            function_name=function_name,
            export_name=export_name,
            is_default=(export_name == "default"),
            file_path=file_path,
        )

    def _find_matching_import(
        self,
        imports: list[ImportInfo],
        source_file: Path,
        importing_file: Path,
        exported: ExportedFunction,
    ) -> tuple[str, ImportInfo] | None:
        """Find if any import in a file imports the target function.

        Args:
            imports: List of imports in the file.
            source_file: Path to the file containing the function definition.
            importing_file: Path to the file being checked for imports.
            exported: Information about how the function is exported.

        Returns:
            Tuple of (imported_name, ImportInfo) if found, None otherwise. For namespace-style
            access the imported name is dotted, e.g. "utils.helper".

        """
        resolver = self._get_resolver()

        for imp in imports:
            resolved = resolver.resolve_import(imp, importing_file)
            if resolved is None or resolved.file_path != source_file:
                continue

            # This import is from our source file - check if it imports our function
            if exported.is_default:
                if imp.default_import:
                    return (imp.default_import, imp)
                if imp.namespace_import:
                    return (f"{imp.namespace_import}.default", imp)
            else:
                export_name = exported.export_name or exported.function_name
                for name, alias in imp.named_imports:
                    if name == export_name:
                        return (alias if alias else name, imp)

                if imp.namespace_import:
                    return (f"{imp.namespace_import}.{export_name}", imp)

                # CommonJS: const helpers = require('./helpers'); helpers.processConfig()
                # Here the default import acts like a namespace.
                if imp.default_import and not imp.named_imports:
                    return (f"{imp.default_import}.{export_name}", imp)

        return None

    def _find_references_in_file(
        self,
        file_path: Path,
        source_code: str,
        function_name: str,
        import_name: str | None,
        analyzer: TreeSitterAnalyzer,
        include_self: bool = True,
    ) -> list[Reference]:
        """Find all references to a function within a single file.

        Args:
            file_path: Path to the file to search.
            source_code: Source code content.
            function_name: Original function name.
            import_name: Name the function is imported as (may be different). A dotted name
                ("namespace.member") selects member-access matching.
            analyzer: TreeSitterAnalyzer instance.
            include_self: Not consulted by this method; kept for interface compatibility.

        Returns:
            List of Reference objects.

        """
        references: list[Reference] = []
        source_bytes = source_code.encode("utf8")
        tree = analyzer.parse(source_bytes)
        lines = source_code.splitlines()

        search_name = import_name if import_name else function_name

        if "." in search_name:
            namespace, member = search_name.split(".", 1)
            self._find_member_calls(
                tree.root_node, source_bytes, lines, file_path, namespace, member, references, None
            )
        else:
            self._find_identifier_references(
                tree.root_node, source_bytes, lines, file_path, search_name, function_name, references, None
            )

        return references

    @staticmethod
    def _node_text(node: Node, source_bytes: bytes) -> str:
        """Return the source text covered by a node."""
        return source_bytes[node.start_byte : node.end_byte].decode("utf8")

    def _enclosing_function_name(self, node: Node, source_bytes: bytes, current: str | None) -> str | None:
        """Return the function name that applies to the children of ``node``.

        A function declaration, a method definition, or a variable initialised with an arrow
        function / function expression starts a new named context; any other node inherits
        ``current``.
        """
        if node.type in ("function_declaration", "method_definition"):
            name_node = node.child_by_field_name("name")
            if name_node is not None:
                return self._node_text(name_node, source_bytes)
        elif node.type == "variable_declarator":
            name_node = node.child_by_field_name("name")
            value_node = node.child_by_field_name("value")
            if (
                name_node is not None
                and value_node is not None
                and value_node.type in ("arrow_function", "function_expression")
            ):
                return self._node_text(name_node, source_bytes)
        return current

    def _find_identifier_references(
        self,
        node: Node,
        source_bytes: bytes,
        lines: list[str],
        file_path: Path,
        search_name: str,
        original_name: str,
        references: list[Reference],
        current_function: str | None,
    ) -> None:
        """Find references to an identifier in the subtree rooted at ``node``.

        The tree is walked in pre-order with an explicit stack, so deeply nested code cannot hit
        the recursion limit. Each matching identifier is reported once, including the callee of
        a direct call.

        Args:
            node: Root tree-sitter node to search from.
            source_bytes: Source code as bytes.
            lines: Source code split into lines.
            file_path: Path to the file.
            search_name: Name to search for.
            original_name: Original function name (not used for matching).
            references: List to append references to.
            current_function: Name of the containing function (for context).

        """
        stack: list[tuple[Node, str | None]] = [(node, current_function)]
        while stack:
            current, enclosing = stack.pop()

            if current.type == "identifier" and self._node_text(current, source_bytes) == search_name:
                ref_type = self._determine_reference_type(current, current.parent, source_bytes)
                if ref_type:
                    references.append(
                        self._create_reference(file_path, current, lines, ref_type, search_name, enclosing)
                    )

            inner = self._enclosing_function_name(current, source_bytes, enclosing)
            # Reversed so children are popped in source order.
            stack.extend((child, inner) for child in reversed(current.children))

    def _find_member_calls(
        self,
        node: Node,
        source_bytes: bytes,
        lines: list[str],
        file_path: Path,
        namespace: str,
        member: str,
        references: list[Reference],
        current_function: str | None,
    ) -> None:
        """Find uses of namespace.member (e.g., utils.helper()) in the subtree rooted at ``node``.

        The callee of a call is reported as "call"; other uses are classified by
        _determine_reference_type. Caller tracking follows the same rules as
        _find_identifier_references, including arrow functions assigned to variables.

        Args:
            node: Root tree-sitter node to search from.
            source_bytes: Source code as bytes.
            lines: Source code split into lines.
            file_path: Path to the file.
            namespace: The namespace/object name.
            member: The member/property name.
            references: List to append references to.
            current_function: Name of the containing function.

        """
        target = f"{namespace}.{member}"
        stack: list[tuple[Node, str | None]] = [(node, current_function)]
        while stack:
            current, enclosing = stack.pop()

            if current.type == "member_expression":
                obj_node = current.child_by_field_name("object")
                prop_node = current.child_by_field_name("property")
                if (
                    obj_node is not None
                    and prop_node is not None
                    and self._node_text(obj_node, source_bytes) == namespace
                    and self._node_text(prop_node, source_bytes) == member
                ):
                    parent = current.parent
                    if parent is None:
                        ref_type = None
                    elif parent.type == "call_expression":
                        callee = parent.child_by_field_name("function")
                        ref_type = "call" if callee is not None and callee.id == current.id else None
                    else:
                        ref_type = self._determine_reference_type(current, parent, source_bytes)

                    if ref_type:
                        references.append(
                            self._create_reference(file_path, current, lines, ref_type, target, enclosing)
                        )

            inner = self._enclosing_function_name(current, source_bytes, enclosing)
            stack.extend((child, inner) for child in reversed(current.children))

    def _is_memoizer(self, callee_text: str) -> bool:
        """Report whether a callee's text contains one of the memoization name hints."""
        lowered = callee_text.lower()
        return any(hint in lowered for hint in self._MEMOIZER_HINTS)

    def _determine_reference_type(self, node: Node, parent: Node | None, source_bytes: bytes) -> str | None:
        """Determine the type of reference based on AST context.

        Args:
            node: The identifier (or member expression) node.
            parent: The parent node.
            source_bytes: Source code as bytes.

        Returns:
            "call", "callback", "memoized", "property" or "reference"; None when the node is an
            import/export specifier, the name being declared, or the object of a member access.

        """
        if parent is None:
            return None

        parent_type = parent.type

        # Import and export specifiers are not usages.
        if parent_type in ("import_specifier", "import_clause", "named_imports", "export_specifier"):
            return None

        # The name being declared is the definition, not a reference.
        if parent_type in ("function_declaration", "method_definition", "variable_declarator"):
            name_node = parent.child_by_field_name("name")
            if name_node is not None and name_node.id == node.id:
                return None

        if parent_type == "arguments":
            # Passed to a memoization wrapper, or to any other function as a callback.
            grandparent = parent.parent
            if grandparent is not None and grandparent.type == "call_expression":
                func_node = grandparent.child_by_field_name("function")
                if func_node is not None and self._is_memoizer(self._node_text(func_node, source_bytes)):
                    return "memoized"
            return "callback"

        if parent_type == "array":
            return "callback"

        if parent_type == "call_expression":
            func_node = parent.child_by_field_name("function")
            # Test the callee first so a function whose own name contains a memoizer hint
            # (e.g. getCache) is still classified as a plain call.
            if func_node is not None and func_node.id == node.id:
                return "call"
            if func_node is not None and self._is_memoizer(self._node_text(func_node, source_bytes)):
                return "memoized"

        if parent_type in ("pair", "property"):
            return "property"

        if parent_type == "member_expression":
            obj_node = parent.child_by_field_name("object")
            if obj_node is not None and obj_node.id == node.id:
                # The object in obj.method; the member access itself is matched elsewhere.
                return None

        return "reference"

    def _create_reference(
        self,
        file_path: Path,
        node: Node,
        lines: list[str],
        ref_type: str,
        import_name: str,
        caller_function: str | None,
    ) -> Reference:
        """Create a Reference object from a node.

        Args:
            file_path: Path to the file.
            node: The tree-sitter node.
            lines: Source code lines.
            ref_type: Type of reference.
            import_name: Name the function was imported as.
            caller_function: Name of the containing function.

        Returns:
            A Reference object with 1-indexed lines and the stripped source line as context.

        """
        row = node.start_point[0]
        context = lines[row] if row < len(lines) else ""

        return Reference(
            file_path=file_path,
            line=row + 1,
            column=node.start_point[1],
            end_line=node.end_point[0] + 1,
            end_column=node.end_point[1],
            context=context.strip(),
            reference_type=ref_type,
            import_name=import_name,
            caller_function=caller_function,
        )

    @staticmethod
    def _reexport_references(
        file_path: Path, lines: list[str], exp: ExportInfo, export_name: str
    ) -> list[Reference]:
        """Build a "reexport" Reference for each name in ``exp`` equal to ``export_name``."""
        context_line = lines[exp.start_line - 1] if 1 <= exp.start_line <= len(lines) else ""
        return [
            Reference(
                file_path=file_path,
                line=exp.start_line,
                column=0,
                end_line=exp.end_line,
                end_column=0,
                context=context_line.strip(),
                reference_type="reexport",
                import_name=alias if alias else name,
                caller_function=None,
            )
            for name, alias in exp.exported_names
            if name == export_name
        ]

    def _find_reexports(
        self,
        file_path: Path,
        source_code: str,
        exported: ExportedFunction,
        analyzer: TreeSitterAnalyzer,
        context: ReferenceSearchContext,
    ) -> list[Reference]:
        """Find re-exports of the function, whatever module they re-export from.

        Re-exports look like: export { helper } from './utils'

        Args:
            file_path: Path to the file being checked.
            source_code: Source code content.
            exported: Information about the original export.
            analyzer: TreeSitterAnalyzer instance.
            context: Search context (not consulted).

        Returns:
            List of Reference objects for re-exports.

        """
        lines = source_code.splitlines()
        export_name = exported.export_name or exported.function_name

        references: list[Reference] = []
        for exp in analyzer.find_exports(source_code):
            if exp.is_reexport:
                references.extend(self._reexport_references(file_path, lines, exp, export_name))
        return references

    def _find_reexports_direct(
        self,
        file_path: Path,
        source_code: str,
        source_file: Path,
        exported: ExportedFunction,
        analyzer: TreeSitterAnalyzer,
    ) -> list[Reference]:
        """Find re-exports whose source module resolves to our source file.

        Args:
            file_path: Path to the file being checked.
            source_code: Source code content.
            source_file: The original source file we're looking for references to.
            exported: Information about the original export.
            analyzer: TreeSitterAnalyzer instance.

        Returns:
            List of Reference objects for re-exports.

        """
        from codeflash.languages.treesitter_utils import ImportInfo

        lines = source_code.splitlines()
        resolver = self._get_resolver()
        export_name = exported.export_name or exported.function_name

        references: list[Reference] = []
        for exp in analyzer.find_exports(source_code):
            if not exp.is_reexport or not exp.reexport_source:
                continue

            # A stand-in import lets the resolver map the re-export's module path to a file.
            probe = ImportInfo(
                module_path=exp.reexport_source,
                default_import=None,
                named_imports=[],
                namespace_import=None,
                is_type_only=False,
                start_line=exp.start_line,
                end_line=exp.end_line,
            )
            resolved = resolver.resolve_import(probe, file_path)
            if resolved is None or resolved.file_path != source_file:
                continue

            references.extend(self._reexport_references(file_path, lines, exp, export_name))

        return references

    def _iter_project_files(self) -> list[Path]:
        """List all JavaScript/TypeScript files in the project.

        The tree is walked once; excluded directories are pruned, so their contents are never
        visited. Directories and files are taken in sorted order.

        Returns:
            List of file paths to search.

        """
        files: list[Path] = []
        for dirpath, dirnames, filenames in os.walk(self.project_root):
            directory = Path(dirpath)
            # In-place assignment is what makes os.walk skip the pruned directories.
            dirnames[:] = sorted(name for name in dirnames if not self._should_exclude(directory / name))
            for filename in sorted(filenames):
                if filename.endswith(self.EXTENSIONS):
                    candidate = directory / filename
                    if not self._should_exclude(candidate):
                        files.append(candidate)
        return files

    def _should_exclude(self, file_path: Path) -> bool:
        """Check if a file should be excluded from search.

        Patterns are matched against the path relative to the project root, so directory names
        above the project never trigger an exclusion. A pattern without "/" is a glob matched
        against each path component; a pattern containing "/" is matched against the whole
        relative path, either as a glob or as a run of consecutive components.

        Args:
            file_path: Path to check.

        Returns:
            True if the file should be excluded.

        """
        try:
            relative = file_path.relative_to(self.project_root)
        except ValueError:
            relative = file_path

        parts = relative.parts
        relative_posix = relative.as_posix()

        for pattern in self.exclude_patterns:
            if "/" in pattern:
                fragment = pattern.strip("/")
                if fnmatchcase(relative_posix, pattern) or f"/{fragment}/" in f"/{relative_posix}/":
                    return True
            elif any(fnmatchcase(part, pattern) for part in parts):
                return True
        return False

    def _read_file(self, file_path: Path) -> str | None:
        """Read a file's contents with caching.

        Args:
            file_path: Path to the file.

        Returns:
            File contents or None if unreadable. Only successful reads are cached.

        """
        if file_path in self._file_cache:
            return self._file_cache[file_path]

        try:
            content = file_path.read_text(encoding="utf-8")
        except (OSError, UnicodeError) as e:
            logger.debug("Could not read file %s: %s", file_path, e)
            return None

        self._file_cache[file_path] = content
        return content
+ + Example: + ```python + from pathlib import Path + from codeflash.languages.javascript.find_references import find_references + + refs = find_references( + function_name="myHelper", + source_file=Path("/my/project/src/utils.ts"), + project_root=Path("/my/project") + ) + for ref in refs: + print(f"{ref.file_path}:{ref.line}:{ref.column} - {ref.reference_type}") + ``` + + """ + if project_root is None: + project_root = source_file.parent + + finder = ReferenceFinder(project_root) + return finder.find_references(function_name, source_file, max_files=max_files) diff --git a/codeflash/languages/javascript/support.py b/codeflash/languages/javascript/support.py index d1eb5d3ad..51e1c2342 100644 --- a/codeflash/languages/javascript/support.py +++ b/codeflash/languages/javascript/support.py @@ -19,6 +19,7 @@ HelperFunction, Language, ParentInfo, + ReferenceInfo, TestInfo, TestResult, ) @@ -964,6 +965,66 @@ def find_helper_functions(self, function: FunctionInfo, project_root: Path) -> l logger.warning("Failed to find helpers for %s: %s", function.name, e) return [] + def find_references( + self, + function: FunctionInfo, + project_root: Path, + tests_root: Path | None = None, + max_files: int = 500, + ) -> list[ReferenceInfo]: + """Find all references (call sites) to a function across the codebase. + + Uses tree-sitter to find all places where a JavaScript/TypeScript function + is called, including direct calls, callbacks, memoized versions, and re-exports. + + Args: + function: The function to find references for. + project_root: Root of the project to search. + tests_root: Root of tests directory (references in tests are excluded). + max_files: Maximum number of files to search. + + Returns: + List of ReferenceInfo objects describing each reference location. 
+ + """ + from codeflash.languages.base import ReferenceInfo + from codeflash.languages.javascript.find_references import ReferenceFinder + + try: + finder = ReferenceFinder(project_root) + refs = finder.find_references(function.name, function.file_path, max_files=max_files) + + # Convert to ReferenceInfo and filter out tests + result: list[ReferenceInfo] = [] + for ref in refs: + # Exclude test files if tests_root is provided + if tests_root: + try: + ref.file_path.relative_to(tests_root) + continue # Skip if in tests_root + except ValueError: + pass # Not in tests_root, include it + + result.append( + ReferenceInfo( + file_path=ref.file_path, + line=ref.line, + column=ref.column, + end_line=ref.end_line, + end_column=ref.end_column, + context=ref.context, + reference_type=ref.reference_type, + import_name=ref.import_name, + caller_function=ref.caller_function, + ) + ) + + return result + + except Exception as e: + logger.warning("Failed to find references for %s: %s", function.name, e) + return [] + # === Code Transformation === def replace_function(self, source: str, function: FunctionInfo, new_source: str) -> str: diff --git a/codeflash/languages/python/support.py b/codeflash/languages/python/support.py index df086d2fa..bf35c7777 100644 --- a/codeflash/languages/python/support.py +++ b/codeflash/languages/python/support.py @@ -13,6 +13,7 @@ HelperFunction, Language, ParentInfo, + ReferenceInfo, TestInfo, TestResult, ) @@ -294,6 +295,120 @@ def find_helper_functions(self, function: FunctionInfo, project_root: Path) -> l return helpers + def find_references( + self, + function: FunctionInfo, + project_root: Path, + tests_root: Path | None = None, + max_files: int = 500, + ) -> list[ReferenceInfo]: + """Find all references (call sites) to a function across the codebase. + + Uses jedi to find all places where a Python function is called. + + Args: + function: The function to find references for. + project_root: Root of the project to search. 
+ tests_root: Root of tests directory (references in tests are excluded). + max_files: Maximum number of files to search. + + Returns: + List of ReferenceInfo objects describing each reference location. + + """ + try: + import jedi + + source = function.file_path.read_text() + + # Find the function position + script = jedi.Script(code=source, path=function.file_path) + names = script.get_names(all_scopes=True, definitions=True) + + function_pos = None + for name in names: + if name.type == "function" and name.name == function.name: + # Check for class parent if it's a method + if function.class_name: + parent = name.parent() + if parent and parent.name == function.class_name and parent.type == "class": + function_pos = (name.line, name.column) + break + else: + function_pos = (name.line, name.column) + break + + if function_pos is None: + return [] + + # Get references using jedi + script = jedi.Script(code=source, path=function.file_path, project=jedi.Project(path=project_root)) + references = script.get_references(line=function_pos[0], column=function_pos[1]) + + result: list[ReferenceInfo] = [] + seen_locations: set[tuple[Path, int, int]] = set() + + for ref in references: + if not ref.module_path: + continue + + ref_path = Path(ref.module_path) + + # Skip the definition itself + if ref_path == function.file_path and ref.line == function_pos[0]: + continue + + # Skip test files + if tests_root: + try: + ref_path.relative_to(tests_root) + continue + except ValueError: + pass + + # Avoid duplicates + loc_key = (ref_path, ref.line, ref.column) + if loc_key in seen_locations: + continue + seen_locations.add(loc_key) + + # Get context line + try: + ref_source = ref_path.read_text() + lines = ref_source.splitlines() + context = lines[ref.line - 1] if ref.line <= len(lines) else "" + except Exception: + context = "" + + # Determine caller function + caller_function = None + try: + parent = ref.parent() + if parent and parent.type == "function": + caller_function = 
parent.name + except Exception: + pass + + result.append( + ReferenceInfo( + file_path=ref_path, + line=ref.line, + column=ref.column, + end_line=ref.line, + end_column=ref.column + len(function.name), + context=context.strip(), + reference_type="call", + import_name=function.name, + caller_function=caller_function, + ) + ) + + return result + + except Exception as e: + logger.warning("Failed to find references for %s: %s", function.name, e) + return [] + # === Code Transformation === def replace_function(self, source: str, function: FunctionInfo, new_source: str) -> str: diff --git a/tests/test_languages/test_find_references.py b/tests/test_languages/test_find_references.py new file mode 100644 index 000000000..8bbd5ce09 --- /dev/null +++ b/tests/test_languages/test_find_references.py @@ -0,0 +1,1088 @@ +"""Comprehensive tests for JavaScript/TypeScript find references functionality. + +These tests are inspired by real-world patterns found in the Appsmith codebase, +covering various import/export patterns, callback usage, memoization, and more. + +Each test verifies: +1. The actual reference values (file, line, column, type, caller) +2. 
The formatted markdown output from _format_references_as_markdown +""" + +import pytest +from pathlib import Path + +from codeflash.languages.javascript.find_references import ( + Reference, + ReferenceFinder, + ExportedFunction, + ReferenceSearchContext, + find_references, +) +from codeflash.languages.base import Language, FunctionInfo, ReferenceInfo +from codeflash.code_utils.code_extractor import _format_references_as_markdown + + +class TestReferenceFinder: + """Tests for ReferenceFinder class.""" + + @pytest.fixture + def project_root(self, tmp_path): + """Create a basic project structure.""" + src_dir = tmp_path / "src" + src_dir.mkdir() + return tmp_path + + @pytest.fixture + def finder(self, project_root): + """Create a ReferenceFinder instance.""" + return ReferenceFinder(project_root) + + def test_init_default_exclude_patterns(self, project_root): + """Test that default exclude patterns are set.""" + finder = ReferenceFinder(project_root) + assert "node_modules" in finder.exclude_patterns + assert "dist" in finder.exclude_patterns + assert ".git" in finder.exclude_patterns + + def test_init_custom_exclude_patterns(self, project_root): + """Test custom exclude patterns.""" + finder = ReferenceFinder(project_root, exclude_patterns=["custom_dir"]) + assert "custom_dir" in finder.exclude_patterns + assert "node_modules" not in finder.exclude_patterns + + def test_should_exclude_node_modules(self, finder, project_root): + """Test that node_modules files are excluded.""" + path = project_root / "node_modules" / "lodash" / "index.js" + assert finder._should_exclude(path) is True + + def test_should_not_exclude_src(self, finder, project_root): + """Test that src files are not excluded.""" + path = project_root / "src" / "utils.ts" + assert finder._should_exclude(path) is False + + +class TestBasicNamedExports: + """Tests for basic named export/import patterns. 
+ + Inspired by Appsmith patterns like: + import { getDynamicBindings, isDynamicValue } from "utils/DynamicBindingUtils"; + """ + + @pytest.fixture + def project_root(self, tmp_path): + """Create project with named export pattern.""" + src_dir = tmp_path / "src" + src_dir.mkdir() + utils_dir = src_dir / "utils" + utils_dir.mkdir() + + # Source file with named export + (utils_dir / "DynamicBindingUtils.ts").write_text( + 'export function getDynamicBindings(value: string): string[] {\n' + ' const regex = /{{([^}]+)}}/g;\n' + ' return [];\n' + '}\n' + ) + + # File that imports and uses the function + (src_dir / "evaluator.ts").write_text( + "import { getDynamicBindings } from './utils/DynamicBindingUtils';\n" + '\n' + 'export function evaluate(expression: string) {\n' + ' const bindings = getDynamicBindings(expression);\n' + ' return bindings;\n' + '}\n' + ) + + # Another file that uses the function + (src_dir / "validator.ts").write_text( + "import { getDynamicBindings } from './utils/DynamicBindingUtils';\n" + '\n' + 'export function validateBindings(input: string) {\n' + ' const bindings = getDynamicBindings(input);\n' + ' return bindings.length > 0;\n' + '}\n' + ) + + return tmp_path + + def test_find_named_export_references_values(self, project_root): + """Test finding references to a named exported function with exact values.""" + finder = ReferenceFinder(project_root) + source_file = project_root / "src" / "utils" / "DynamicBindingUtils.ts" + + refs = finder.find_references("getDynamicBindings", source_file) + + # Sort refs by file path for consistent ordering + refs_sorted = sorted(refs, key=lambda r: (str(r.file_path), r.line)) + + # Should find 2 references + assert len(refs_sorted) == 2 + + # Check evaluator.ts reference + eval_ref = next(r for r in refs_sorted if "evaluator.ts" in str(r.file_path)) + assert eval_ref.line == 4 + assert eval_ref.reference_type == "call" + assert eval_ref.caller_function == "evaluate" + assert eval_ref.import_name == 
"getDynamicBindings" + assert "getDynamicBindings(expression)" in eval_ref.context + + # Check validator.ts reference + val_ref = next(r for r in refs_sorted if "validator.ts" in str(r.file_path)) + assert val_ref.line == 4 + assert val_ref.reference_type == "call" + assert val_ref.caller_function == "validateBindings" + assert val_ref.import_name == "getDynamicBindings" + assert "getDynamicBindings(input)" in val_ref.context + + def test_format_references_as_markdown_named_exports(self, project_root): + """Test _format_references_as_markdown output for named exports.""" + finder = ReferenceFinder(project_root) + source_file = project_root / "src" / "utils" / "DynamicBindingUtils.ts" + + refs = finder.find_references("getDynamicBindings", source_file) + + # Convert to ReferenceInfo and sort for consistent ordering + ref_infos = sorted([ + ReferenceInfo( + file_path=r.file_path, + line=r.line, + column=r.column, + end_line=r.end_line, + end_column=r.end_column, + context=r.context, + reference_type=r.reference_type, + import_name=r.import_name, + caller_function=r.caller_function, + ) + for r in refs + ], key=lambda r: str(r.file_path)) + + markdown = _format_references_as_markdown(ref_infos, source_file, project_root, Language.TYPESCRIPT) + + expected_markdown = ( + '```typescript:src/evaluator.ts\n' + 'function evaluate(expression: string) {\n' + ' const bindings = getDynamicBindings(expression);\n' + ' return bindings;\n' + '}\n' + '```\n' + '```typescript:src/validator.ts\n' + 'function validateBindings(input: string) {\n' + ' const bindings = getDynamicBindings(input);\n' + ' return bindings.length > 0;\n' + '}\n' + '```\n' + ) + assert markdown == expected_markdown + + +class TestDefaultExports: + """Tests for default export/import patterns.""" + + @pytest.fixture + def project_root(self, tmp_path): + """Create project with default export pattern.""" + src_dir = tmp_path / "src" + src_dir.mkdir() + + # Source file with default export + (src_dir / 
"helper.ts").write_text( + 'function processData(data: any[]) {\n' + ' return data.filter(item => item.active);\n' + '}\n' + '\n' + 'export default processData;\n' + ) + + # File that imports the default export + (src_dir / "main.ts").write_text( + "import processData from './helper';\n" + '\n' + 'export function handleData(items: any[]) {\n' + ' const processed = processData(items);\n' + ' return processed.length;\n' + '}\n' + ) + + # File that imports with a different name + (src_dir / "alternative.ts").write_text( + "import myProcessor from './helper';\n" + '\n' + 'export function process(items: any[]) {\n' + ' return myProcessor(items);\n' + '}\n' + ) + + return tmp_path + + def test_find_default_export_references_values(self, project_root): + """Test finding references to a default exported function with exact values.""" + finder = ReferenceFinder(project_root) + source_file = project_root / "src" / "helper.ts" + + refs = finder.find_references("processData", source_file) + + # Should find references in both files + ref_files = {str(ref.file_path) for ref in refs} + assert any("main.ts" in f for f in ref_files) + assert any("alternative.ts" in f for f in ref_files) + + # Check main.ts reference (uses original name) + main_ref = next(r for r in refs if "main.ts" in str(r.file_path)) + assert main_ref.line == 4 + assert main_ref.reference_type == "call" + assert main_ref.caller_function == "handleData" + assert main_ref.import_name == "processData" + + # Check alternative.ts reference (uses alias) + alt_ref = next(r for r in refs if "alternative.ts" in str(r.file_path)) + assert alt_ref.line == 4 + assert alt_ref.reference_type == "call" + assert alt_ref.caller_function == "process" + assert alt_ref.import_name == "myProcessor" + + def test_format_references_as_markdown_default_exports(self, project_root): + """Test markdown output for default exports with aliases.""" + finder = ReferenceFinder(project_root) + source_file = project_root / "src" / "helper.ts" + + 
refs = finder.find_references("processData", source_file) + ref_infos = sorted([ + ReferenceInfo( + file_path=r.file_path, line=r.line, column=r.column, + end_line=r.end_line, end_column=r.end_column, context=r.context, + reference_type=r.reference_type, import_name=r.import_name, + caller_function=r.caller_function, + ) + for r in refs + ], key=lambda r: str(r.file_path)) + + markdown = _format_references_as_markdown(ref_infos, source_file, project_root, Language.TYPESCRIPT) + + expected_markdown = ( + '```typescript:src/alternative.ts\n' + 'function process(items: any[]) {\n' + ' return myProcessor(items);\n' + '}\n' + '```\n' + '```typescript:src/main.ts\n' + 'function handleData(items: any[]) {\n' + ' const processed = processData(items);\n' + ' return processed.length;\n' + '}\n' + '```\n' + ) + assert markdown == expected_markdown + + +class TestReExports: + """Tests for re-export patterns.""" + + @pytest.fixture + def project_root(self, tmp_path): + """Create project with re-export pattern.""" + src_dir = tmp_path / "src" + src_dir.mkdir() + utils_dir = src_dir / "utils" + utils_dir.mkdir() + + # Original function file + (utils_dir / "filterUtils.ts").write_text( + 'export function filterBySearchTerm(items: any[], term: string) {\n' + ' return items.filter(i => i.name.includes(term));\n' + '}\n' + ) + + # Index file that re-exports + (utils_dir / "index.ts").write_text( + "export { filterBySearchTerm } from './filterUtils';\n" + ) + + # Consumer that imports from index + (src_dir / "consumer.ts").write_text( + "import { filterBySearchTerm } from './utils';\n" + '\n' + 'export function searchItems(items: any[], query: string) {\n' + ' return filterBySearchTerm(items, query);\n' + '}\n' + ) + + return tmp_path + + def test_find_reexport_reference_values(self, project_root): + """Test finding re-export references with exact values.""" + finder = ReferenceFinder(project_root) + source_file = project_root / "src" / "utils" / "filterUtils.ts" + + refs = 
finder.find_references("filterBySearchTerm", source_file) + + # Should find re-export in index.ts + reexport_refs = [r for r in refs if r.reference_type == "reexport"] + assert len(reexport_refs) == 1 + assert "index.ts" in str(reexport_refs[0].file_path) + assert reexport_refs[0].import_name == "filterBySearchTerm" + + # Should find call in consumer.ts (through re-export chain) + call_refs = [r for r in refs if r.reference_type == "call"] + assert len(call_refs) >= 1 + consumer_ref = next((r for r in call_refs if "consumer.ts" in str(r.file_path)), None) + assert consumer_ref is not None + assert consumer_ref.line == 4 + assert consumer_ref.caller_function == "searchItems" + + def test_format_references_as_markdown_reexports(self, project_root): + """Test markdown output for re-exports.""" + finder = ReferenceFinder(project_root) + source_file = project_root / "src" / "utils" / "filterUtils.ts" + + refs = finder.find_references("filterBySearchTerm", source_file) + ref_infos = sorted([ + ReferenceInfo( + file_path=r.file_path, line=r.line, column=r.column, + end_line=r.end_line, end_column=r.end_column, context=r.context, + reference_type=r.reference_type, import_name=r.import_name, + caller_function=r.caller_function, + ) + for r in refs + ], key=lambda r: str(r.file_path)) + + markdown = _format_references_as_markdown(ref_infos, source_file, project_root, Language.TYPESCRIPT) + + expected_markdown = ( + '```typescript:src/consumer.ts\n' + 'function searchItems(items: any[], query: string) {\n' + ' return filterBySearchTerm(items, query);\n' + '}\n' + '```\n' + '```typescript:src/utils/index.ts\n' + "export { filterBySearchTerm } from './filterUtils';\n" + '```\n' + ) + assert markdown == expected_markdown + + +class TestCallbackPatterns: + """Tests for functions passed as callbacks.""" + + @pytest.fixture + def project_root(self, tmp_path): + """Create project with callback patterns.""" + src_dir = tmp_path / "src" + src_dir.mkdir() + + # Helper function + 
(src_dir / "transforms.ts").write_text( + 'export function normalizeItem(item: any) {\n' + ' return { ...item, normalized: true };\n' + '}\n' + ) + + # Consumer using callbacks + (src_dir / "processor.ts").write_text( + "import { normalizeItem } from './transforms';\n" + '\n' + 'export function processItems(items: any[]) {\n' + ' const normalized = items.map(normalizeItem);\n' + ' return normalized;\n' + '}\n' + ) + + return tmp_path + + def test_find_callback_references_values(self, project_root): + """Test finding functions used as callbacks with exact values.""" + finder = ReferenceFinder(project_root) + source_file = project_root / "src" / "transforms.ts" + + refs = finder.find_references("normalizeItem", source_file) + + # Should find the callback reference + callback_refs = [r for r in refs if r.reference_type == "callback"] + assert len(callback_refs) >= 1 + + callback_ref = callback_refs[0] + assert "processor.ts" in str(callback_ref.file_path) + assert callback_ref.line == 4 + assert callback_ref.caller_function == "processItems" + assert "items.map(normalizeItem)" in callback_ref.context + + def test_format_references_as_markdown_callbacks(self, project_root): + """Test markdown output for callback patterns.""" + finder = ReferenceFinder(project_root) + source_file = project_root / "src" / "transforms.ts" + + refs = finder.find_references("normalizeItem", source_file) + ref_infos = [ + ReferenceInfo( + file_path=r.file_path, line=r.line, column=r.column, + end_line=r.end_line, end_column=r.end_column, context=r.context, + reference_type=r.reference_type, import_name=r.import_name, + caller_function=r.caller_function, + ) + for r in refs + ] + + markdown = _format_references_as_markdown(ref_infos, source_file, project_root, Language.TYPESCRIPT) + + expected_markdown = ( + '```typescript:src/processor.ts\n' + 'function processItems(items: any[]) {\n' + ' const normalized = items.map(normalizeItem);\n' + ' return normalized;\n' + '}\n' + '```\n' + ) + assert 
expected_markdown == markdown + + +class TestAliasImports: + """Tests for functions imported with aliases.""" + + @pytest.fixture + def project_root(self, tmp_path): + """Create project with alias import patterns.""" + src_dir = tmp_path / "src" + src_dir.mkdir() + + # Source file + (src_dir / "utils.ts").write_text( + 'export function computeValue(input: number): number {\n' + ' return input * 2;\n' + '}\n' + ) + + # File using alias + (src_dir / "consumer.ts").write_text( + "import { computeValue as calculate } from './utils';\n" + '\n' + 'export function processNumber(n: number) {\n' + ' const result = calculate(n);\n' + ' return result + 10;\n' + '}\n' + ) + + return tmp_path + + def test_find_aliased_import_references_values(self, project_root): + """Test finding references when function is imported with alias.""" + finder = ReferenceFinder(project_root) + source_file = project_root / "src" / "utils.ts" + + refs = finder.find_references("computeValue", source_file) + + # Should find the reference even though it's called as "calculate" + assert len(refs) == 1 + ref = refs[0] + assert "consumer.ts" in str(ref.file_path) + assert ref.line == 4 + assert ref.reference_type == "call" + assert ref.import_name == "calculate" # The aliased name + assert ref.caller_function == "processNumber" + assert "calculate(n)" in ref.context + + def test_format_references_as_markdown_aliases(self, project_root): + """Test markdown output for aliased imports.""" + finder = ReferenceFinder(project_root) + source_file = project_root / "src" / "utils.ts" + + refs = finder.find_references("computeValue", source_file) + ref_infos = [ + ReferenceInfo( + file_path=r.file_path, line=r.line, column=r.column, + end_line=r.end_line, end_column=r.end_column, context=r.context, + reference_type=r.reference_type, import_name=r.import_name, + caller_function=r.caller_function, + ) + for r in refs + ] + + markdown = _format_references_as_markdown(ref_infos, source_file, project_root, 
Language.TYPESCRIPT) + + expected_markdown = ( + '```typescript:src/consumer.ts\n' + 'function processNumber(n: number) {\n' + ' const result = calculate(n);\n' + ' return result + 10;\n' + '}\n' + '```\n' + ) + assert expected_markdown == markdown + + +class TestNamespaceImports: + """Tests for namespace import patterns.""" + + @pytest.fixture + def project_root(self, tmp_path): + """Create project with namespace import patterns.""" + src_dir = tmp_path / "src" + src_dir.mkdir() + + # Source file with multiple exports + (src_dir / "mathUtils.ts").write_text( + 'export function add(a: number, b: number): number {\n' + ' return a + b;\n' + '}\n' + ) + + # File using namespace import + (src_dir / "calculator.ts").write_text( + "import * as MathUtils from './mathUtils';\n" + '\n' + 'export function calculate(a: number, b: number) {\n' + ' return MathUtils.add(a, b);\n' + '}\n' + ) + + return tmp_path + + def test_find_namespace_import_references_values(self, project_root): + """Test finding references via namespace imports with exact values.""" + finder = ReferenceFinder(project_root) + source_file = project_root / "src" / "mathUtils.ts" + + refs = finder.find_references("add", source_file) + + assert len(refs) == 1 + ref = refs[0] + assert "calculator.ts" in str(ref.file_path) + assert ref.line == 4 + assert ref.reference_type == "call" + assert ref.import_name == "MathUtils.add" + assert ref.caller_function == "calculate" + assert "MathUtils.add(a, b)" in ref.context + + def test_format_references_as_markdown_namespace(self, project_root): + """Test markdown output for namespace imports.""" + finder = ReferenceFinder(project_root) + source_file = project_root / "src" / "mathUtils.ts" + + refs = finder.find_references("add", source_file) + ref_infos = [ + ReferenceInfo( + file_path=r.file_path, line=r.line, column=r.column, + end_line=r.end_line, end_column=r.end_column, context=r.context, + reference_type=r.reference_type, import_name=r.import_name, + 
caller_function=r.caller_function, + ) + for r in refs + ] + + markdown = _format_references_as_markdown(ref_infos, source_file, project_root, Language.TYPESCRIPT) + + expected_markdown = ( + '```typescript:src/calculator.ts\n' + 'function calculate(a: number, b: number) {\n' + ' return MathUtils.add(a, b);\n' + '}\n' + '```\n' + ) + assert expected_markdown == markdown + + +class TestMemoizedFunctions: + """Tests for memoized function patterns.""" + + @pytest.fixture + def project_root(self, tmp_path): + """Create project with memoized function patterns.""" + src_dir = tmp_path / "src" + src_dir.mkdir() + + # Source file with function to be memoized + (src_dir / "expensive.ts").write_text( + 'export function computeExpensive(x: number): number {\n' + ' return x * x;\n' + '}\n' + ) + + # File that memoizes the function + (src_dir / "memoized.ts").write_text( + "import memoize from 'micro-memoize';\n" + "import { computeExpensive } from './expensive';\n" + '\n' + 'export const memoizedCompute = memoize(computeExpensive);\n' + '\n' + 'export function process(x: number) {\n' + ' return computeExpensive(x) + memoizedCompute(x);\n' + '}\n' + ) + + return tmp_path + + def test_find_memoized_function_references_values(self, project_root): + """Test finding references to functions passed to memoize.""" + finder = ReferenceFinder(project_root) + source_file = project_root / "src" / "expensive.ts" + + refs = finder.find_references("computeExpensive", source_file) + + # Should find memoize call and direct call + assert len(refs) >= 2 + + # Check for memoized reference + memo_refs = [r for r in refs if r.reference_type == "memoized"] + assert len(memo_refs) >= 1 + memo_ref = memo_refs[0] + assert "memoized.ts" in str(memo_ref.file_path) + assert "memoize(computeExpensive)" in memo_ref.context + + # Check for direct call + call_refs = [r for r in refs if r.reference_type == "call"] + assert len(call_refs) >= 1 + + +class TestSameFileReferences: + """Tests for references within 
class TestComplexMultiFileScenarios:
    """Multi-file scenarios (utility, index re-export, component consumer) inspired by Appsmith."""

    @pytest.fixture
    def project_root(self, tmp_path):
        """Create ``src/utils`` and ``src/components`` with a re-exported utility function."""
        sources = tmp_path / "src"
        sources.mkdir()
        utils_dir = sources / "utils"
        utils_dir.mkdir()
        components_dir = sources / "components"
        components_dir.mkdir()

        # The utility whose references are searched for.
        widget_utils = utils_dir / "widgetUtils.ts"
        widget_utils.write_text(
            'export function isLargeWidget(type: string): boolean {\n'
            " return ['TABLE', 'LIST'].includes(type);\n"
            '}\n'
        )

        # Barrel file that re-exports the utility.
        utils_index = utils_dir / "index.ts"
        utils_index.write_text(
            "export { isLargeWidget } from './widgetUtils';\n"
        )

        # Component that imports the utility through the barrel file.
        widget_component = components_dir / "Widget.tsx"
        widget_component.write_text(
            "import { isLargeWidget } from '../utils';\n"
            '\n'
            'export function Widget({ type }: { type: string }) {\n'
            ' const isLarge = isLargeWidget(type);\n'
            ' return isLarge;\n'
            '}\n'
        )

        return tmp_path

    def test_find_all_references_across_codebase_values(self, project_root):
        """The re-export and the call inside ``Widget`` are both found with exact values."""
        target_file = project_root / "src" / "utils" / "widgetUtils.ts"
        found = ReferenceFinder(project_root).find_references("isLargeWidget", target_file)

        # Exactly one re-export, located in index.ts.
        reexports = [hit for hit in found if hit.reference_type == "reexport"]
        assert len(reexports) == 1
        assert "index.ts" in str(reexports[0].file_path)

        # At least one call, and the first one reported for Widget.tsx has known coordinates.
        calls = [hit for hit in found if hit.reference_type == "call"]
        assert len(calls) >= 1
        widget_calls = [hit for hit in calls if "Widget.tsx" in str(hit.file_path)]
        widget_hit = widget_calls[0] if widget_calls else None
        assert widget_hit is not None
        assert widget_hit.line == 4
        assert widget_hit.caller_function == "Widget"

    def test_format_references_as_markdown_complex(self, project_root):
        """Markdown output lists the component block first, then the re-export line."""
        target_file = project_root / "src" / "utils" / "widgetUtils.ts"
        found = ReferenceFinder(project_root).find_references("isLargeWidget", target_file)

        # Convert finder results to ReferenceInfo and order them by path for a stable result.
        infos = []
        for hit in found:
            infos.append(
                ReferenceInfo(
                    file_path=hit.file_path,
                    line=hit.line,
                    column=hit.column,
                    end_line=hit.end_line,
                    end_column=hit.end_column,
                    context=hit.context,
                    reference_type=hit.reference_type,
                    import_name=hit.import_name,
                    caller_function=hit.caller_function,
                )
            )
        infos.sort(key=lambda info: str(info.file_path))

        rendered = _format_references_as_markdown(infos, target_file, project_root, Language.TYPESCRIPT)

        expected = (
            '```typescript:src/components/Widget.tsx\n'
            'function Widget({ type }: { type: string }) {\n'
            ' const isLarge = isLargeWidget(type);\n'
            ' return isLarge;\n'
            '}\n'
            '```\n'
            '```typescript:src/utils/index.ts\n'
            "export { isLargeWidget } from './widgetUtils';\n"
            '```\n'
        )
        assert rendered == expected
class TestCommonJSPatterns:
    """Reference discovery across ``require`` / ``module.exports`` modules."""

    @pytest.fixture
    def project_root(self, tmp_path):
        """Lay out a CommonJS exporter (``helpers.js``) and a consumer (``main.js``)."""
        sources = tmp_path / "src"
        sources.mkdir()

        # Module that exports processConfig through module.exports.
        helpers_js = sources / "helpers.js"
        helpers_js.write_text(
            'function processConfig(config) {\n'
            ' return { ...config, processed: true };\n'
            '}\n'
            '\n'
            'module.exports = { processConfig };\n'
        )

        # Consumer that pulls processConfig in via a destructured require().
        main_js = sources / "main.js"
        main_js.write_text(
            "const { processConfig } = require('./helpers');\n"
            '\n'
            'function handleConfig(config) {\n'
            ' return processConfig(config);\n'
            '}\n'
            '\n'
            'module.exports = handleConfig;\n'
        )

        return tmp_path

    def test_find_commonjs_references_values(self, project_root):
        """The call inside ``handleConfig`` is reported with exact line, type and caller."""
        target_file = project_root / "src" / "helpers.js"
        found = ReferenceFinder(project_root).find_references("processConfig", target_file)

        assert len(found) >= 1
        main_hits = [hit for hit in found if "main.js" in str(hit.file_path)]
        main_hit = main_hits[0] if main_hits else None
        assert main_hit is not None
        assert main_hit.line == 4
        assert main_hit.reference_type == "call"
        assert main_hit.caller_function == "handleConfig"

    def test_format_references_as_markdown_commonjs(self, project_root):
        """Markdown output contains the calling function from ``main.js`` as a javascript block."""
        target_file = project_root / "src" / "helpers.js"
        found = ReferenceFinder(project_root).find_references("processConfig", target_file)

        # Convert finder results to ReferenceInfo and order them by path for a stable result.
        infos = []
        for hit in found:
            infos.append(
                ReferenceInfo(
                    file_path=hit.file_path,
                    line=hit.line,
                    column=hit.column,
                    end_line=hit.end_line,
                    end_column=hit.end_column,
                    context=hit.context,
                    reference_type=hit.reference_type,
                    import_name=hit.import_name,
                    caller_function=hit.caller_function,
                )
            )
        infos.sort(key=lambda info: str(info.file_path))

        rendered = _format_references_as_markdown(infos, target_file, project_root, Language.JAVASCRIPT)

        expected = (
            '```javascript:src/main.js\n'
            'function handleConfig(config) {\n'
            ' return processConfig(config);\n'
            '}\n'
            '```\n'
        )
        assert rendered == expected
class TestReferenceDataclass:
    """Construction checks for the ``Reference`` record."""

    def test_reference_creation(self, tmp_path):
        """Every field passed to the constructor (other than the path) reads back unchanged."""
        # Insertion order mirrors the keyword order used when building the object.
        field_values = {
            "line": 10,
            "column": 5,
            "end_line": 10,
            "end_column": 15,
            "context": "const result = myFunction();",
            "reference_type": "call",
            "import_name": "myFunction",
            "caller_function": "processData",
        }

        created = Reference(file_path=tmp_path / "test.ts", **field_values)

        for field_name, expected_value in field_values.items():
            assert getattr(created, field_name) == expected_value

    def test_reference_without_caller(self, tmp_path):
        """Omitting ``caller_function`` leaves it as ``None``."""
        field_values = {
            "line": 1,
            "column": 0,
            "end_line": 1,
            "end_column": 10,
            "context": "export { fn } from './module';",
            "reference_type": "reexport",
            "import_name": "fn",
        }

        created = Reference(file_path=tmp_path / "test.ts", **field_values)

        assert created.caller_function is None
class TestReferenceSearchContext:
    """Checks on the values held by a ``ReferenceSearchContext``."""

    def test_context_defaults(self):
        """A context built with no arguments has no visited files and a 1000-file cap."""
        default_context = ReferenceSearchContext()

        observed = (default_context.visited_files, default_context.max_files)
        assert observed == (set(), 1000)

    def test_context_custom_max_files(self):
        """An explicit ``max_files`` argument is stored as given."""
        requested_limit = 500
        limited_context = ReferenceSearchContext(max_files=requested_limit)

        assert limited_context.max_files == 500
source_file = project_root / "src" / "a.ts" + + # Should not hang or crash + refs = finder.find_references("funcA", source_file) + + # Should find reference in b.ts + b_refs = [r for r in refs if "b.ts" in str(r.file_path)] + assert len(b_refs) >= 1 + assert b_refs[0].caller_function == "callsA" + + def test_syntax_error_graceful_handling(self, project_root): + """Test that syntax errors in files are handled gracefully.""" + src_dir = project_root / "src" + + (src_dir / "valid.ts").write_text( + 'export function validFunction() {\n' + ' return 42;\n' + '}\n' + ) + + # Create a file with syntax error + (src_dir / "invalid.ts").write_text( + "import { validFunction } from './valid';\n" + '\n' + 'export function broken( {\n' + ' return validFunction(\n' + '}\n' + ) + + finder = ReferenceFinder(project_root) + source_file = project_root / "src" / "valid.ts" + + # Should not crash + refs = finder.find_references("validFunction", source_file) + assert isinstance(refs, list)