Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
225 changes: 215 additions & 10 deletions codeflash/code_utils/code_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1563,23 +1563,228 @@ def is_numerical_code(code_string: str, function_name: str | None = None) -> boo
def get_opt_review_metrics(
source_code: str, file_path: Path, qualified_name: str, project_root: Path, tests_root: Path, language: Language
) -> str:
if language != Language.PYTHON:
# TODO: {Claude} handle function refrences for other languages
return ""
"""Get function reference metrics for optimization review.

Uses the LanguageSupport abstraction to find references, supporting both Python and JavaScript/TypeScript.

Args:
source_code: Source code of the file containing the function.
file_path: Path to the file.
qualified_name: Qualified name of the function (e.g., "module.ClassName.method").
project_root: Root of the project.
tests_root: Root of the tests directory.
language: The programming language.

Returns:
Markdown-formatted string with code blocks showing calling functions.
"""
from codeflash.languages.base import FunctionInfo, ParentInfo, ReferenceInfo
from codeflash.languages.registry import get_language_support

start_time = time.perf_counter()

try:
# Get the language support
lang_support = get_language_support(language)
if lang_support is None:
return ""

# Parse qualified name to get function name and class name
qualified_name_split = qualified_name.rsplit(".", maxsplit=1)
if len(qualified_name_split) == 1:
target_function, target_class = qualified_name_split[0], None
function_name, class_name = qualified_name_split[0], None
else:
target_function, target_class = qualified_name_split[1], qualified_name_split[0]
matches = get_fn_references_jedi(
source_code, file_path, project_root, target_function, target_class
) # jedi is not perfect, it doesn't capture aliased references
calling_fns_details = find_occurances(qualified_name, str(file_path), matches, project_root, tests_root)
function_name, class_name = qualified_name_split[1], qualified_name_split[0]

# Create a FunctionInfo for the function
# We don't have full line info here, so we'll use defaults
parents = ()
if class_name:
parents = (ParentInfo(name=class_name, type="ClassDef"),)

func_info = FunctionInfo(
name=function_name,
file_path=file_path,
start_line=1,
end_line=1,
parents=parents,
language=language,
)

# Find references using language support
references = lang_support.find_references(func_info, project_root, tests_root, max_files=500)

if not references:
return ""

# Format references as markdown code blocks
calling_fns_details = _format_references_as_markdown(
references, file_path, project_root, language
)

except Exception as e:
logger.debug(f"Error getting function references: {e}")
calling_fns_details = ""
logger.debug(f"Investigate {e}")

end_time = time.perf_counter()
logger.debug(f"Got function references in {end_time - start_time:.2f} seconds")
return calling_fns_details


def _format_references_as_markdown(
references: list, file_path: Path, project_root: Path, language: Language
) -> str:
"""Format references as markdown code blocks with calling function code.

Args:
references: List of ReferenceInfo objects.
file_path: Path to the source file (to exclude).
project_root: Root of the project.
language: The programming language.

Returns:
Markdown-formatted string.
"""
# Group references by file
refs_by_file: dict[Path, list] = {}
for ref in references:
# Exclude the source file's definition/import references
if ref.file_path == file_path and ref.reference_type in ("import", "reexport"):
continue

if ref.file_path not in refs_by_file:
refs_by_file[ref.file_path] = []
refs_by_file[ref.file_path].append(ref)

fn_call_context = ""
context_len = 0

for ref_file, file_refs in refs_by_file.items():
if context_len > MAX_CONTEXT_LEN_REVIEW:
break

try:
path_relative = ref_file.relative_to(project_root)
except ValueError:
continue

# Get syntax highlighting language
ext = ref_file.suffix.lstrip(".")
if language == Language.PYTHON:
lang_hint = "python"
elif ext in ("ts", "tsx"):
lang_hint = "typescript"
else:
lang_hint = "javascript"

# Read the file to extract calling function context
try:
file_content = ref_file.read_text(encoding="utf-8")
lines = file_content.splitlines()
except Exception:
continue

# Get unique caller functions from this file
callers_seen: set[str] = set()
caller_contexts: list[str] = []

for ref in file_refs:
caller = ref.caller_function or "<module>"
if caller in callers_seen:
continue
callers_seen.add(caller)

# Extract context around the reference
if ref.caller_function:
# Try to extract the full calling function
func_code = _extract_calling_function(file_content, ref.caller_function, ref.line, language)
if func_code:
caller_contexts.append(func_code)
context_len += len(func_code)
else:
# Module-level call - show a few lines of context
start_line = max(0, ref.line - 3)
end_line = min(len(lines), ref.line + 2)
context_code = "\n".join(lines[start_line:end_line])
caller_contexts.append(context_code)
context_len += len(context_code)

if caller_contexts:
fn_call_context += f"```{lang_hint}:{path_relative}\n"
fn_call_context += "\n".join(caller_contexts)
fn_call_context += "\n```\n"

return fn_call_context


def _extract_calling_function(source_code: str, function_name: str, ref_line: int, language: Language) -> str | None:
"""Extract the source code of a calling function.

Args:
source_code: Full source code of the file.
function_name: Name of the function to extract.
ref_line: Line number where the reference is.
language: The programming language.

Returns:
Source code of the function, or None if not found.
"""
if language == Language.PYTHON:
return _extract_calling_function_python(source_code, function_name, ref_line)
else:
return _extract_calling_function_js(source_code, function_name, ref_line)


def _extract_calling_function_python(source_code: str, function_name: str, ref_line: int) -> str | None:
"""Extract the source code of a calling function in Python."""
try:
import ast

tree = ast.parse(source_code)
lines = source_code.splitlines()

for node in ast.walk(tree):
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
if node.name == function_name:
# Check if the reference line is within this function
start_line = node.lineno
end_line = node.end_lineno or start_line
if start_line <= ref_line <= end_line:
return "\n".join(lines[start_line - 1 : end_line])
Comment on lines +1744 to +1753
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚡️Codeflash found 105% (1.05x) speedup for _extract_calling_function_python in codeflash/code_utils/code_extractor.py

⏱️ Runtime : 81.6 milliseconds 39.8 milliseconds (best of 156 runs)

📝 Explanation and details

The optimization achieves a 105% speedup (from 81.6ms to 39.8ms) by fundamentally changing how the AST is traversed to find the target function.

Key Performance Improvements:

  1. Pruned Tree Traversal: The original code uses ast.walk() which visits every node in the AST (~50,663 nodes in the profiler results). The optimized version uses an explicit stack with ast.iter_child_nodes() and prunes entire subtrees whose line ranges cannot contain ref_line. This dramatically reduces nodes visited (only ~1,271 nodes in the optimized version) - a 97.5% reduction in node visits.

  2. Lazy Line Splitting: The original code eagerly calls source_code.splitlines() upfront for every invocation. The optimized version only splits lines after finding a matching function, eliminating unnecessary string processing when returning None or for most iterations.

  3. Early Exit via Line Range Filtering: By checking node_lineno <= ref_line <= node_end_lineno before exploring children, the optimization avoids descending into AST branches that are guaranteed not to contain the reference line. This is especially effective when the target function is early in large files.

Why This Matters:

From the line profiler, the original code spent 79.2% of time (257ms) just iterating through ast.walk() and 6.8% (22ms) on isinstance checks across all nodes. The optimized version reduces this to 3.3% on child iteration and 0.1% on isinstance checks - spending most time on the unavoidable ast.parse() instead.

Test Case Performance:

  • Small functions (2-5 lines): 13-30% faster - modest gains from skipping line splitting
  • Medium files (50-200 functions): 14-24% faster - benefits from pruning irrelevant function subtrees
  • Large files (500+ functions): 88-125% faster - massive gains as pruning eliminates thousands of unnecessary node visits. For example, test_extraction_performance_with_many_lines improved from 54.9ms to 24.4ms (125% faster)
  • Decorator edge cases: 21-69% faster when ref_line is outside function bounds, as early returns avoid full traversal

The optimization is particularly effective for large codebases and scenarios where the target function appears early in the file, as demonstrated by the 2-10x speedup in large-scale tests.

Correctness verification report:

Test Status
⚙️ Existing Unit Tests 🔘 None Found
🌀 Generated Regression Tests 34 Passed
⏪ Replay Tests 🔘 None Found
🔎 Concolic Coverage Tests 🔘 None Found
📊 Tests Coverage 100.0%
🌀 Click to see Generated Regression Tests
from __future__ import annotations

# imports
import pytest  # used for our unit tests
from codeflash.code_utils.code_extractor import \
    _extract_calling_function_python

def test_basic_simple_function():
    # Basic case: simple function spanning multiple lines should be returned wholly.
    source = (
        "def foo():\n"                # line 1: function definition
        "    x = 1\n"                 # line 2: body
        "    return x\n"             # line 3: body
        "\n"                         # line 4: blank
        "y = 2\n"                    # line 5: unrelated code
    )
    # Request with a reference line inside the function (line 2).
    codeflash_output = _extract_calling_function_python(source, "foo", 2); result = codeflash_output # 39.5μs -> 34.9μs (13.2% faster)

def test_async_function_extraction():
    # Ensure async functions are handled as well.
    source = (
        "async def afunc():\n"      # line 1: async function def
        "    await something()\n"   # line 2: body (syntactically valid for parsing)
        "    return 3\n"           # line 3: body
    )
    # Reference line inside the async function body.
    codeflash_output = _extract_calling_function_python(source, "afunc", 2); result = codeflash_output # 39.0μs -> 33.6μs (16.0% faster)

def test_nested_function_extraction():
    # Nested function: inner should be extracted when reference line falls inside it.
    source = (
        "def outer():\n"                # line 1
        "    a = 0\n"                   # line 2
        "    def inner():\n"            # line 3: inner def
        "        return 5\n"            # line 4: inner body
        "    return inner()\n"          # line 5
    )
    # Reference line inside the inner function (line 4).
    codeflash_output = _extract_calling_function_python(source, "inner", 4); result = codeflash_output # 48.5μs -> 41.1μs (18.2% faster)

def test_multiple_functions_same_name_selects_correct_one():
    # Two functions with the same name; selecting by reference line should pick the correct occurrence.
    source = (
        "def duplicate():\n"       # line 1 (first occurrence)
        "    a = 1\n"              # line 2
        "\n"                       # line 3
        "def duplicate():\n"       # line 4 (second occurrence)
        "    b = 2\n"              # line 5
        "\n"                       # line 6
    )
    # Reference line in the second function body (line 5).
    codeflash_output = _extract_calling_function_python(source, "duplicate", 5); result = codeflash_output # 38.8μs -> 31.9μs (21.6% faster)

def test_decorated_function_excludes_decorator_lines_and_requires_ref_in_def_or_body():
    # Decorators appear above the def. The AST node.lineno refers to the def line, so decorators are not included.
    source = (
        "@some_decorator\n"        # line 1 (decorator)
        "def decorated():\n"       # line 2 (def)
        "    return 10\n"          # line 3
    )
    # Reference line inside the function (line 3) should return lines 2..3 (decorator excluded).
    codeflash_output = _extract_calling_function_python(source, "decorated", 3); result_inside = codeflash_output # 31.3μs -> 25.9μs (20.9% faster)
    # Reference line on the decorator (line 1) is not considered inside the function, so expect None.
    codeflash_output = _extract_calling_function_python(source, "decorated", 1); result_decorator_line = codeflash_output # 23.7μs -> 14.1μs (68.8% faster)

def test_boundary_conditions_start_and_end_lines():
    # Function where reference line matches exactly the start or end lines should still return the function.
    source = (
        "def boundary():\n"         # line 1
        "    x = 9\n"               # line 2 (end)
    )
    # Reference equals start line.
    codeflash_output = _extract_calling_function_python(source, "boundary", 1); res_start = codeflash_output # 29.0μs -> 24.0μs (20.6% faster)
    # Reference equals end line.
    codeflash_output = _extract_calling_function_python(source, "boundary", 2); res_end = codeflash_output # 17.0μs -> 13.1μs (29.7% faster)
    # Reference one line before start should return None.
    codeflash_output = _extract_calling_function_python(source, "boundary", 0); res_before = codeflash_output # 18.9μs -> 9.92μs (90.6% faster)

def test_single_line_function_def_with_pass():
    # Single-line function (def with pass on same line) should be extracted; end_lineno may equal start_lineno.
    source = "def single(): pass\n"
    # Reference at the single line.
    codeflash_output = _extract_calling_function_python(source, "single", 1); result = codeflash_output # 25.0μs -> 19.7μs (27.0% faster)

def test_syntax_error_returns_none():
    # Invalid Python source should be caught by the function and return None.
    source = "def oops(:\n    pass\n"
    # Parsing raises SyntaxError internally, but function catches and returns None.
    codeflash_output = _extract_calling_function_python(source, "oops", 1); result = codeflash_output # 32.0μs -> 31.4μs (1.85% faster)

def test_function_name_not_found_returns_none():
    # When the requested function name does not exist in the source, result should be None.
    source = (
        "def exists():\n"
        "    return 1\n"
    )
    codeflash_output = _extract_calling_function_python(source, "missing", 2); result = codeflash_output # 31.9μs -> 30.1μs (5.89% faster)

def test_large_scale_many_functions():
    # Large-scale scenario: generate many small functions and ensure the target one is extracted correctly.
    # Keep total functions under 1000 as required; choose 200 functions for this test.
    count = 200
    parts = []
    for i in range(count):
        # Each function occupies 2 lines + 1 blank line => 3 lines per function.
        parts.append(f"def f{i}():\n")
        parts.append(f"    return {i}\n")
        parts.append("\n")
    source = "".join(parts)
    # Choose a specific function in the middle to extract.
    target_index = 150
    # Compute the line number where the return statement of f150 appears:
    # For each function before it, there are 3 lines. So start_line = target_index * 3 + 1
    start_line = target_index * 3 + 1
    return_line = start_line + 1
    # Request extraction with the return line as the reference.
    codeflash_output = _extract_calling_function_python(source, f"f{target_index}", return_line); result = codeflash_output # 1.26ms -> 1.02ms (23.5% faster)
    # Expect exactly the two lines that define the target function.
    expected = f"def f{target_index}():\n    return {target_index}"
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
import ast

# imports
import pytest
from codeflash.code_utils.code_extractor import \
    _extract_calling_function_python

# unit tests

class TestBasicFunctionality:
    """Test basic functionality of _extract_calling_function_python."""

    def test_simple_function_extraction(self):
        """Test extracting a simple single-line function."""
        source_code = "def hello():\n    return 'world'\n"
        codeflash_output = _extract_calling_function_python(source_code, "hello", 1); result = codeflash_output # 30.9μs -> 25.5μs (21.3% faster)

    def test_simple_function_extraction_multi_line(self):
        """Test extracting a simple multi-line function."""
        source_code = "def add(a, b):\n    result = a + b\n    return result\n"
        codeflash_output = _extract_calling_function_python(source_code, "add", 2); result = codeflash_output # 40.7μs -> 35.2μs (15.7% faster)

    def test_function_with_parameters(self):
        """Test extracting a function with multiple parameters."""
        source_code = "def greet(name, greeting='Hello'):\n    return greeting + ' ' + name\n"
        codeflash_output = _extract_calling_function_python(source_code, "greet", 1); result = codeflash_output # 40.9μs -> 35.2μs (16.1% faster)

    def test_async_function_extraction(self):
        """Test extracting an async function."""
        source_code = "async def fetch_data():\n    return 'data'\n"
        codeflash_output = _extract_calling_function_python(source_code, "fetch_data", 1); result = codeflash_output # 29.0μs -> 24.3μs (19.4% faster)

    def test_function_with_docstring(self):
        """Test extracting a function that contains a docstring."""
        source_code = 'def documented():\n    """This is documented."""\n    return True\n'
        codeflash_output = _extract_calling_function_python(source_code, "documented", 2); result = codeflash_output # 31.1μs -> 26.2μs (18.8% faster)

    def test_function_with_nested_calls(self):
        """Test extracting a function that contains nested function calls."""
        source_code = "def outer():\n    return inner()\n"
        codeflash_output = _extract_calling_function_python(source_code, "outer", 1); result = codeflash_output # 30.2μs -> 25.3μs (19.7% faster)

    def test_function_at_beginning_of_file(self):
        """Test extracting a function at the very beginning of the source code."""
        source_code = "def first():\n    pass\n\ndef second():\n    pass\n"
        codeflash_output = _extract_calling_function_python(source_code, "first", 1); result = codeflash_output # 30.5μs -> 26.2μs (16.5% faster)

    def test_function_in_middle_of_file(self):
        """Test extracting a function in the middle of the source code."""
        source_code = "def first():\n    pass\n\ndef middle():\n    return 42\n\ndef last():\n    pass\n"
        codeflash_output = _extract_calling_function_python(source_code, "middle", 4); result = codeflash_output # 39.5μs -> 32.4μs (21.7% faster)

    def test_function_at_end_of_file(self):
        """Test extracting a function at the end of the source code."""
        source_code = "def first():\n    pass\n\ndef last():\n    return 99\n"
        codeflash_output = _extract_calling_function_python(source_code, "last", 4); result = codeflash_output # 34.2μs -> 27.4μs (24.6% faster)

    def test_function_with_complex_logic(self):
        """Test extracting a function with complex conditional logic."""
        source_code = (
            "def check(x):\n"
            "    if x > 0:\n"
            "        return 'positive'\n"
            "    else:\n"
            "        return 'non-positive'\n"
        )
        codeflash_output = _extract_calling_function_python(source_code, "check", 3); result = codeflash_output # 45.1μs -> 39.7μs (13.6% faster)

class TestLargeScaleScenarios:
    """Test performance and scalability with larger code samples."""

    def test_extraction_from_file_with_many_functions(self):
        """Test extracting a function from a file with many other functions."""
        # Create source code with 100 functions
        source_parts = []
        for i in range(100):
            source_parts.append(f"def func_{i}():\n    return {i}\n\n")
        source_code = "".join(source_parts)
        
        # Extract function in the middle
        codeflash_output = _extract_calling_function_python(source_code, "func_50", 505); result = codeflash_output # 950μs -> 503μs (88.9% faster)

    def test_extraction_from_file_with_large_function(self):
        """Test extracting a very large function with many lines."""
        # Create a function with 500 lines
        lines = ["def large_func():\n"]
        for i in range(500):
            lines.append(f"    x_{i} = {i}\n")
        lines.append("    return x_0\n")
        source_code = "".join(lines)
        
        codeflash_output = _extract_calling_function_python(source_code, "large_func", 250); result = codeflash_output # 1.46ms -> 1.41ms (3.91% faster)

    def test_extraction_with_deeply_nested_structures(self):
        """Test extracting a function with deeply nested code structures."""
        lines = ["def nested():\n"]
        indent = 1
        for i in range(20):
            lines.append("    " * indent + f"if x_{i}:\n")
            indent += 1
        lines.append("    " * indent + "return True\n")
        source_code = "".join(lines)
        
        codeflash_output = _extract_calling_function_python(source_code, "nested", 5); result = codeflash_output # 97.1μs -> 90.4μs (7.45% faster)

    def test_extraction_with_many_string_literals(self):
        """Test extracting a function containing many string literals."""
        lines = ["def many_strings():\n"]
        for i in range(200):
            lines.append(f'    s_{i} = "string_{i}"\n')
        lines.append("    return s_0\n")
        source_code = "".join(lines)
        
        codeflash_output = _extract_calling_function_python(source_code, "many_strings", 100); result = codeflash_output # 610μs -> 565μs (7.87% faster)

    def test_extraction_with_mixed_async_and_sync_functions(self):
        """Test extracting functions from code with both async and sync definitions."""
        lines = []
        for i in range(50):
            if i % 2 == 0:
                lines.append(f"def sync_{i}():\n    return {i}\n\n")
            else:
                lines.append(f"async def async_{i}():\n    return {i}\n\n")
        source_code = "".join(lines)
        
        codeflash_output = _extract_calling_function_python(source_code, "async_25", 76); result = codeflash_output # 310μs -> 270μs (14.6% faster)

    def test_extraction_with_large_source_file(self):
        """Test extracting from a source code file with significant total size."""
        # Create a file with 500 functions
        lines = []
        for i in range(500):
            lines.append(f"def func_{i}():\n")
            for j in range(5):
                lines.append(f"    value_{j} = {i * j}\n")
            lines.append("    return value_0\n\n")
        source_code = "".join(lines)
        
        codeflash_output = _extract_calling_function_python(source_code, "func_250", 1255); result = codeflash_output # 21.1ms -> 10.8ms (95.5% faster)

    def test_extraction_performance_with_many_lines(self):
        """Test that extraction performs reasonably with many total lines."""
        # Create a 10000+ line file
        lines = []
        for i in range(500):
            lines.append(f"def function_{i}():\n")
            for j in range(10):
                lines.append(f"    statement_{j} = {i} + {j}\n")
            lines.append("    return statement_0\n\n")
        source_code = "".join(lines)
        
        # Should still find and extract correctly
        codeflash_output = _extract_calling_function_python(source_code, "function_100", 1010); result = codeflash_output # 54.9ms -> 24.4ms (125% faster)

    def test_extraction_with_unicode_content(self):
        """Test extracting a function containing Unicode characters."""
        source_code = (
            "def unicode_func():\n"
            '    text = "Hello 世界 🌍"\n'
            "    return text\n"
        )
        codeflash_output = _extract_calling_function_python(source_code, "unicode_func", 2); result = codeflash_output # 46.8μs -> 38.2μs (22.5% faster)

    def test_extraction_with_multiple_decorators(self):
        """Test extracting a function with multiple decorators."""
        source_code = (
            "@decorator1\n"
            "@decorator2\n"
            "@decorator3\n"
            "def multi_decorated():\n"
            "    return 'result'\n"
        )
        codeflash_output = _extract_calling_function_python(source_code, "multi_decorated", 4); result = codeflash_output # 39.5μs -> 32.7μs (20.8% faster)

    def test_extraction_boundary_ref_line_equals_start(self):
        """Test when ref_line equals the function start line."""
        source_code = "def boundary():\n    x = 1\n    y = 2\n    return x + y\n"
        codeflash_output = _extract_calling_function_python(source_code, "boundary", 1); result = codeflash_output # 41.0μs -> 34.7μs (18.3% faster)

    def test_extraction_boundary_ref_line_equals_end(self):
        """Test when ref_line equals the function end line."""
        source_code = "def boundary():\n    x = 1\n    y = 2\n    return x + y\n"
        codeflash_output = _extract_calling_function_python(source_code, "boundary", 4); result = codeflash_output # 37.6μs -> 32.3μs (16.2% faster)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To test or edit this optimization locally git merge codeflash/optimize-pr1226-2026-02-01T20.48.02

Click to see suggested changes
Suggested change
lines = source_code.splitlines()
for node in ast.walk(tree):
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
if node.name == function_name:
# Check if the reference line is within this function
start_line = node.lineno
end_line = node.end_lineno or start_line
if start_line <= ref_line <= end_line:
return "\n".join(lines[start_line - 1 : end_line])
# Use an explicit stack and prune subtrees whose lineno/end_lineno
# ranges do not include ref_line to avoid walking the whole tree.
stack = [tree]
while stack:
node = stack.pop()
# If node has concrete line range info and the ref_line lies outside it,
# skip exploring this subtree entirely.
node_lineno = getattr(node, "lineno", None)
node_end_lineno = getattr(node, "end_lineno", None)
if node_lineno is not None and node_end_lineno is not None:
if not (node_lineno <= ref_line <= (node_end_lineno or node_lineno)):
continue
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
if node.name == function_name:
# Check if the reference line is within this function
start_line = node.lineno
end_line = node.end_lineno or start_line
if start_line <= ref_line <= end_line:
lines = source_code.splitlines()
return "\n".join(lines[start_line - 1 : end_line])
# Push children for further inspection
for child in ast.iter_child_nodes(node):
stack.append(child)

return None
except Exception:
return None


def _extract_calling_function_js(source_code: str, function_name: str, ref_line: int) -> str | None:
"""Extract the source code of a calling function in JavaScript/TypeScript.

Args:
source_code: Full source code of the file.
function_name: Name of the function to extract.
ref_line: Line number where the reference is (helps identify the right function).

Returns:
Source code of the function, or None if not found.
"""
try:
from codeflash.languages.treesitter_utils import TreeSitterAnalyzer, TreeSitterLanguage

# Try TypeScript first, fall back to JavaScript
for lang in [TreeSitterLanguage.TYPESCRIPT, TreeSitterLanguage.TSX, TreeSitterLanguage.JAVASCRIPT]:
try:
analyzer = TreeSitterAnalyzer(lang)
functions = analyzer.find_functions(source_code, include_methods=True)

for func in functions:
if func.name == function_name:
# Check if the reference line is within this function
if func.start_line <= ref_line <= func.end_line:
return func.source_text
break
except Exception:
continue

return None
except Exception:
return None
54 changes: 54 additions & 0 deletions codeflash/languages/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,37 @@ class FunctionFilterCriteria:
max_lines: int | None = None


@dataclass
class ReferenceInfo:
"""Information about a reference (call site) to a function.

This class captures information about where a function is called
from, including the file, line number, context, and caller function.

Attributes:
file_path: Path to the file containing the reference.
line: Line number (1-indexed).
column: Column number (0-indexed).
end_line: End line number (1-indexed).
end_column: End column number (0-indexed).
context: The line of code containing the reference.
reference_type: Type of reference ("call", "callback", "memoized", "import", "reexport").
import_name: Name used to import the function (may differ from original).
caller_function: Name of the function containing this reference (or None for module-level).

"""

file_path: Path
line: int
column: int
end_line: int
end_column: int
context: str
reference_type: str
import_name: str | None
caller_function: str | None = None


@runtime_checkable
class LanguageSupport(Protocol):
"""Protocol defining what a language implementation must provide.
Expand Down Expand Up @@ -357,6 +388,29 @@ def find_helper_functions(self, function: FunctionInfo, project_root: Path) -> l
"""
...

def find_references(
self, function: FunctionInfo, project_root: Path, tests_root: Path | None = None, max_files: int = 500
) -> list[ReferenceInfo]:
"""Find all references (call sites) to a function across the codebase.

This method finds all places where a function is called, including:
- Direct calls
- Callbacks (passed to other functions)
- Memoized versions
- Re-exports

Args:
function: The function to find references for.
project_root: Root of the project to search.
tests_root: Root of tests directory (references in tests are excluded).
max_files: Maximum number of files to search.

Returns:
List of ReferenceInfo objects describing each reference location.

"""
...

# === Code Transformation ===

def replace_function(self, source: str, function: FunctionInfo, new_source: str) -> str:
Expand Down
Loading
Loading