109 changes: 109 additions & 0 deletions codeflash/languages/javascript/instrument.py
@@ -901,6 +901,115 @@ def is_relevant_import(module_path: str) -> bool:
return test_code


def fix_import_path_for_test_location(
test_code: str, source_file_path: Path, test_file_path: Path, module_root: Path
) -> str:
"""Fix import paths in generated test code to be relative to test file location.

The AI may generate tests with import paths that are relative to the module root
(e.g., 'apps/web/app/file') instead of relative to where the test file is located
(e.g., '../../app/file'). This function fixes such imports.

Args:
test_code: The generated test code.
source_file_path: Absolute path to the source file being tested.
test_file_path: Absolute path to where the test file will be written.
module_root: Root directory of the module/project.

Returns:
Test code with corrected import paths.

"""
import os

# Calculate the correct relative import path from test file to source file
test_dir = test_file_path.parent
try:
correct_rel_path = os.path.relpath(source_file_path, test_dir)
correct_rel_path = correct_rel_path.replace("\\", "/")
# Remove file extension for JS/TS imports
for ext in [".tsx", ".ts", ".jsx", ".js", ".mjs", ".cjs"]:
if correct_rel_path.endswith(ext):
correct_rel_path = correct_rel_path[: -len(ext)]
break
# Ensure it starts with ./ or ../
if not correct_rel_path.startswith("."):
correct_rel_path = "./" + correct_rel_path
except ValueError:
# Can't compute relative path (different drives on Windows)
return test_code

# Try to compute what incorrect path the AI might have generated
# The AI often uses module_root-relative paths like 'apps/web/app/...'
try:
source_rel_to_module = os.path.relpath(source_file_path, module_root)
source_rel_to_module = source_rel_to_module.replace("\\", "/")
# Remove extension
for ext in [".tsx", ".ts", ".jsx", ".js", ".mjs", ".cjs"]:
if source_rel_to_module.endswith(ext):
source_rel_to_module = source_rel_to_module[: -len(ext)]
break
except ValueError:
return test_code

# Also check for project root-relative paths (including module_root in path)
try:
project_root = module_root.parent if module_root.name in ["src", "lib", "app", "web", "apps"] else module_root
source_rel_to_project = os.path.relpath(source_file_path, project_root)
source_rel_to_project = source_rel_to_project.replace("\\", "/")
for ext in [".tsx", ".ts", ".jsx", ".js", ".mjs", ".cjs"]:
if source_rel_to_project.endswith(ext):
source_rel_to_project = source_rel_to_project[: -len(ext)]
break
except ValueError:
source_rel_to_project = None

# Source file name (for matching module paths that end with the file name)
source_name = source_file_path.stem

# Patterns to find import statements
# ESM: import { func } from 'path' or import func from 'path'
esm_import_pattern = re.compile(r"(import\s+(?:{[^}]+}|\w+)\s+from\s+['\"])([^'\"]+)(['\"])")
# CommonJS: const { func } = require('path') or const func = require('path')
cjs_require_pattern = re.compile(
r"((?:const|let|var)\s+(?:{[^}]+}|\w+)\s*=\s*require\s*\(\s*['\"])([^'\"]+)(['\"])"
)

def should_fix_path(import_path: str) -> bool:
"""Check if this import path looks like it should point to our source file."""
# Skip relative imports that already look correct
if import_path.startswith(("./", "../")):
return False
# Skip package imports (no path separators or start with @)
if "/" not in import_path and "\\" not in import_path:
return False
if import_path.startswith("@") and "/" in import_path:
# Could be an alias like @/utils - skip these
return False
# Check if it looks like it points to our source file
if import_path == source_rel_to_module:
return True
if source_rel_to_project and import_path == source_rel_to_project:
return True
if import_path.endswith((source_name, "/" + source_name)):
return True
Review comment on lines +993 to +995:

Low risk, but worth noting: The endswith(source_name) match on the stem alone is quite broad. If the source file is named something common like utils.ts or index.ts, this could match unrelated imports (e.g., some/other/utils would also get rewritten to the relative path of this source file).

Consider either:

  • Adding a check that the import path shares more path segments with the source, or
  • Requiring a stricter match (e.g., matching the last 2 path segments)

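A minimal sketch of the reviewer's second option, using a hypothetical helper that is not part of this diff; should_fix_path could call it in place of the broad endswith(source_name) check:

def _matches_last_two_segments(import_path: str, source_rel: str) -> bool:
    """Hypothetical stricter check: require the final two path segments to match."""
    import_segments = import_path.rstrip("/").split("/")
    source_segments = source_rel.rstrip("/").split("/")
    # With fewer than two segments on either side, fall back to comparing only the last one.
    if len(import_segments) < 2 or len(source_segments) < 2:
        return import_segments[-1] == source_segments[-1]
    return import_segments[-2:] == source_segments[-2:]

Under this check, an import like some/other/utils would no longer be rewritten for a source file at app/lib/utils, since other/utils and lib/utils differ.
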
return False

def fix_import(match: re.Match[str]) -> str:
"""Replace incorrect import path with correct relative path."""
prefix = match.group(1)
import_path = match.group(2)
suffix = match.group(3)

if should_fix_path(import_path):
logger.debug(f"Fixing import path: {import_path} -> {correct_rel_path}")
return f"{prefix}{correct_rel_path}{suffix}"
return match.group(0)

test_code = esm_import_pattern.sub(fix_import, test_code)
return cjs_require_pattern.sub(fix_import, test_code)

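To make the docstring's example concrete, here is an illustrative call with hypothetical paths (the repo layout and file names below are invented for the example, not taken from this diff):

from pathlib import Path

# Hypothetical layout: module_root is the repository root, the source lives under
# apps/web/app, and the generated test will be written under apps/web/tests/unit.
source = Path("/repo/apps/web/app/file.ts")
test = Path("/repo/apps/web/tests/unit/file.test.ts")
generated = "import { helper } from 'apps/web/app/file';"

fixed = fix_import_path_for_test_location(generated, source, test, module_root=Path("/repo"))
# The module-root-relative specifier is rewritten relative to the test file's directory:
# fixed == "import { helper } from '../../app/file';"
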

def get_instrumented_test_path(original_path: Path, mode: str) -> Path:
"""Generate path for instrumented test file.

72 changes: 72 additions & 0 deletions codeflash/languages/javascript/parse.py
@@ -175,6 +175,19 @@ def parse_jest_test_xml(
logger.debug(f"Found {marker_count} timing start markers in Jest stdout")
else:
logger.debug(f"No timing start markers found in Jest stdout (len={len(global_stdout)})")
# Check for END markers with duration (perf test markers)
end_marker_count = len(jest_end_pattern.findall(global_stdout))
if end_marker_count > 0:
logger.debug(
f"[PERF-DEBUG] Found {end_marker_count} END timing markers with duration in Jest stdout"
)
# Sample a few markers to verify loop indices
end_samples = list(jest_end_pattern.finditer(global_stdout))[:5]
for sample in end_samples:
groups = sample.groups()
logger.debug(f"[PERF-DEBUG] Sample END marker: loopIndex={groups[3]}, duration={groups[5]}")
else:
logger.debug("[PERF-DEBUG] No END markers with duration found in Jest stdout")
except (AttributeError, UnicodeDecodeError):
global_stdout = ""

@@ -197,6 +210,14 @@ def parse_jest_test_xml(
key = match.groups()[:5]
end_matches_dict[key] = match

# Debug: log suite-level END marker parsing for perf tests
if end_matches_dict:
# Get unique loop indices from the parsed END markers
loop_indices = sorted({int(k[3]) if k[3].isdigit() else 1 for k in end_matches_dict})
logger.debug(
f"[PERF-DEBUG] Suite {suite_count}: parsed {len(end_matches_dict)} END markers from suite_stdout, loop_index range: {min(loop_indices)}-{max(loop_indices)}"
)

# Also collect timing markers from testcase-level system-out (Vitest puts output at testcase level)
for tc in suite:
tc_system_out = tc._elem.find("system-out") # noqa: SLF001
@@ -327,6 +348,13 @@ def parse_jest_test_xml(
sanitized_test_name = re.sub(r"[!#: ()\[\]{}|\\/*?^$.+\-]", "_", test_name)
matching_starts = [m for m in start_matches if sanitized_test_name in m.group(2)]

# Debug: log which branch we're taking
logger.debug(
f"[FLOW-DEBUG] Testcase '{test_name[:50]}': "
f"total_start_matches={len(start_matches)}, matching_starts={len(matching_starts)}, "
f"total_end_matches={len(end_matches_dict)}"
)

# For performance tests (capturePerf), there are no START markers - only END markers with duration
# Check for END markers directly if no START markers found
matching_ends_direct = []
@@ -337,6 +365,28 @@ def parse_jest_test_xml(
# end_key is (module, testName, funcName, loopIndex, invocationId)
if len(end_key) >= 2 and sanitized_test_name in end_key[1]:
matching_ends_direct.append(end_match)
# Debug: log matching results for perf tests
if matching_ends_direct:
loop_indices = [int(m.groups()[3]) if m.groups()[3].isdigit() else 1 for m in matching_ends_direct]
logger.debug(
f"[PERF-MATCH] Testcase '{test_name[:40]}': matched {len(matching_ends_direct)} END markers, "
f"loop_index range: {min(loop_indices)}-{max(loop_indices)}"
)
elif end_matches_dict:
# No matches but we have END markers - check why
sample_keys = list(end_matches_dict.keys())[:3]
logger.debug(
f"[PERF-MISMATCH] Testcase '{test_name[:40]}': no matches found. "
f"sanitized_test_name='{sanitized_test_name[:50]}', "
f"sample end_keys={[k[1][:30] if len(k) >= 2 else k for k in sample_keys]}"
)

# Log if we're skipping the matching_ends_direct branch
if matching_starts and end_matches_dict:
logger.debug(
f"[FLOW-SKIP] Testcase '{test_name[:40]}': has {len(matching_starts)} START markers, "
f"skipping {len(end_matches_dict)} END markers (behavior test mode)"
)

if not matching_starts and not matching_ends_direct:
# No timing markers found - use JUnit XML time attribute as fallback
@@ -373,11 +423,13 @@ def parse_jest_test_xml(
)
elif matching_ends_direct:
# Performance test format: process END markers directly (no START markers)
loop_indices_found = []
for end_match in matching_ends_direct:
groups = end_match.groups()
# groups: (module, testName, funcName, loopIndex, invocationId, durationNs)
func_name = groups[2]
loop_index = int(groups[3]) if groups[3].isdigit() else 1
loop_indices_found.append(loop_index)
line_id = groups[4]
try:
runtime = int(groups[5])
@@ -403,6 +455,12 @@ def parse_jest_test_xml(
stdout="",
)
)
if loop_indices_found:
logger.debug(
f"[LOOP-DEBUG] Testcase '{test_name}': processed {len(matching_ends_direct)} END markers, "
f"loop_index range: {min(loop_indices_found)}-{max(loop_indices_found)}, "
f"total results so far: {len(test_results.test_results)}"
)
else:
# Process each timing marker
for match in matching_starts:
@@ -454,5 +512,19 @@ def parse_jest_test_xml(
f"Jest XML parsing complete: {len(test_results.test_results)} results "
f"from {suite_count} suites, {testcase_count} testcases"
)
# Debug: show loop_index distribution for perf analysis
if test_results.test_results:
loop_indices = [r.loop_index for r in test_results.test_results]
unique_loop_indices = sorted(set(loop_indices))
min_idx, max_idx = min(unique_loop_indices), max(unique_loop_indices)
logger.debug(
f"[LOOP-SUMMARY] Results loop_index: min={min_idx}, max={max_idx}, "
f"unique_count={len(unique_loop_indices)}, total_results={len(loop_indices)}"
)
if max_idx == 1 and len(loop_indices) > 1:
logger.warning(
f"[LOOP-WARNING] All {len(loop_indices)} results have loop_index=1. "
"Perf test markers may not have been parsed correctly."
)

return test_results
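
As a concrete illustration of the matching above, the sanitization regex and the END-key substring check line up like this (the test and marker names are hypothetical, and the marker is assumed to carry the test name in the same sanitized form):

import re

sanitized_test_name = re.sub(r"[!#: ()\[\]{}|\\/*?^$.+\-]", "_", "formats a price (USD)")
# -> "formats_a_price__USD_"

# end_key layout from above: (module, testName, funcName, loopIndex, invocationId)
end_key = ("price.test", "formats_a_price__USD_", "formatPrice", "3", "0")

# The perf-test branch uses a plain substring check against the marker's testName field.
matches = len(end_key) >= 2 and sanitized_test_name in end_key[1]   # True
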
2 changes: 2 additions & 0 deletions codeflash/languages/javascript/support.py
@@ -2134,13 +2134,15 @@ def run_benchmarking_tests(
from codeflash.languages.test_framework import get_js_test_framework_or_default

framework = test_framework or get_js_test_framework_or_default()
logger.debug("run_benchmarking_tests called with framework=%s", framework)

# Use JS-specific high max_loops - actual loop count is limited by target_duration
effective_max_loops = self.JS_BENCHMARKING_MAX_LOOPS

if framework == "vitest":
from codeflash.languages.javascript.vitest_runner import run_vitest_benchmarking_tests

logger.debug("Dispatching to run_vitest_benchmarking_tests")
return run_vitest_benchmarking_tests(
test_paths=test_paths,
test_env=test_env,
64 changes: 59 additions & 5 deletions codeflash/languages/javascript/vitest_runner.py
@@ -192,7 +192,7 @@ def _ensure_codeflash_vitest_config(project_root: Path) -> Path | None:
logger.debug("Detected vitest workspace configuration - skipping custom config")
return None

codeflash_config_path = project_root / "codeflash.vitest.config.js"
codeflash_config_path = project_root / "codeflash.vitest.config.mjs"

# If already exists, use it
if codeflash_config_path.exists():
@@ -281,7 +281,7 @@ def _build_vitest_behavioral_command(

# For monorepos with restrictive vitest configs (e.g., include: test/**/*.test.ts),
# we need to create a custom config that allows all test patterns.
# This is done by creating a codeflash.vitest.config.js file.
# This is done by creating a codeflash.vitest.config.mjs file.
if project_root:
codeflash_vitest_config = _ensure_codeflash_vitest_config(project_root)
if codeflash_vitest_config:
@@ -520,6 +520,9 @@
) -> tuple[Path, subprocess.CompletedProcess]:
"""Run Vitest benchmarking tests with external looping from Python.

NOTE: This function MUST use benchmarking_file_path (perf tests with capturePerf),
NOT instrumented_behavior_file_path (behavior tests with capture).

Uses external process-level looping to run tests multiple times and
collect timing data. This matches the Python pytest approach where
looping is controlled externally for simplicity.
@@ -544,6 +547,26 @@
# Get performance test files
test_files = [Path(file.benchmarking_file_path) for file in test_paths.test_files if file.benchmarking_file_path]

# Log test file selection
total_test_files = len(test_paths.test_files)
perf_test_files = len(test_files)
logger.debug(
f"Vitest benchmark test file selection: {perf_test_files}/{total_test_files} have benchmarking_file_path"
)
if perf_test_files == 0:
logger.warning("No perf test files found! Cannot run benchmarking tests.")
for tf in test_paths.test_files:
logger.warning(
f"Test file: behavior={tf.instrumented_behavior_file_path}, perf={tf.benchmarking_file_path}"
)
elif perf_test_files < total_test_files:
for tf in test_paths.test_files:
if not tf.benchmarking_file_path:
logger.warning(f"Missing benchmarking_file_path: behavior={tf.instrumented_behavior_file_path}")
else:
for tf in test_files[:3]: # Log first 3 perf test files
logger.debug(f"Using perf test file: {tf}")

# Use provided project_root, or detect it as fallback
if project_root is None and test_files:
project_root = _find_vitest_project_root(test_files[0])
@@ -574,14 +597,25 @@
vitest_env["CODEFLASH_PERF_STABILITY_CHECK"] = "true" if stability_check else "false"
vitest_env["CODEFLASH_LOOP_INDEX"] = "1"

# Set test module for marker identification (use first test file as reference)
if test_files:
test_module_path = str(
test_files[0].relative_to(effective_cwd)
if test_files[0].is_relative_to(effective_cwd)
else test_files[0].name
)
vitest_env["CODEFLASH_TEST_MODULE"] = test_module_path
logger.debug(f"[VITEST-BENCH] Set CODEFLASH_TEST_MODULE={test_module_path}")

# Total timeout for the entire benchmark run
total_timeout = max(120, (target_duration_ms // 1000) + 60, timeout or 120)

logger.debug(f"Running Vitest benchmarking tests: {' '.join(vitest_cmd)}")
logger.debug(f"[VITEST-BENCH] Running Vitest benchmarking tests: {' '.join(vitest_cmd)}")
logger.debug(
f"Vitest benchmarking config: min_loops={min_loops}, max_loops={max_loops}, "
f"[VITEST-BENCH] Config: min_loops={min_loops}, max_loops={max_loops}, "
f"target_duration={target_duration_ms}ms, stability_check={stability_check}"
)
logger.debug(f"[VITEST-BENCH] Environment: CODEFLASH_PERF_LOOP_COUNT={vitest_env.get('CODEFLASH_PERF_LOOP_COUNT')}")

total_start_time = time.time()

@@ -606,7 +640,27 @@
result = subprocess.CompletedProcess(args=vitest_cmd, returncode=-1, stdout="", stderr="Vitest not found")

wall_clock_seconds = time.time() - total_start_time
logger.debug(f"Vitest benchmarking completed in {wall_clock_seconds:.2f}s")
logger.debug(f"[VITEST-BENCH] Completed in {wall_clock_seconds:.2f}s, returncode={result.returncode}")

# Debug: Check for END markers with duration (perf test format)
if result.stdout:
import re

perf_end_pattern = re.compile(r"!######[^:]+:[^:]+:[^:]+:(\d+):[^:]+:(\d+)######!")
perf_matches = list(perf_end_pattern.finditer(result.stdout))
if perf_matches:
loop_indices = [int(m.group(1)) for m in perf_matches]
logger.debug(
f"[VITEST-BENCH] Found {len(perf_matches)} perf END markers in stdout, "
f"loop_index range: {min(loop_indices)}-{max(loop_indices)}"
)
else:
logger.debug(f"[VITEST-BENCH] No perf END markers found in stdout (len={len(result.stdout)})")
# Check if there are behavior END markers instead
behavior_end_pattern = re.compile(r"!######[^:]+:[^:]+:[^:]+:\d+:[^#]+######!")
behavior_matches = list(behavior_end_pattern.finditer(result.stdout))
if behavior_matches:
logger.debug(f"[VITEST-BENCH] Found {len(behavior_matches)} behavior END markers instead (no duration)")

return result_file_path, result

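For reference, a quick sanity check of the perf END-marker regex above; the marker values here are hypothetical but follow the (module, testName, funcName, loopIndex, invocationId, durationNs) layout described in parse.py:

import re

perf_end_pattern = re.compile(r"!######[^:]+:[^:]+:[^:]+:(\d+):[^:]+:(\d+)######!")

# Hypothetical marker as it would appear in Vitest stdout.
sample_stdout = "!######utils.test:adds_numbers:add:7:0_0:152340######!"
m = perf_end_pattern.search(sample_stdout)
loop_index = int(m.group(1))    # 7  (fourth field)
duration_ns = int(m.group(2))   # 152340  (sixth field, nanoseconds)
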
7 changes: 7 additions & 0 deletions codeflash/optimization/function_optimizer.py
@@ -2363,6 +2363,12 @@ def establish_original_code_baseline(
)
console.rule()
with progress_bar("Running performance benchmarks..."):
logger.debug(
f"[BENCHMARK-START] Starting benchmarking tests with {len(self.test_files.test_files)} test files"
)
for idx, tf in enumerate(self.test_files.test_files):
logger.debug(f"[BENCHMARK-FILES] Test file {idx}: perf_file={tf.benchmarking_file_path}")

if self.function_to_optimize.is_async and is_python():
from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function

@@ -2380,6 +2386,7 @@
enable_coverage=False,
code_context=code_context,
)
logger.debug(f"[BENCHMARK-DONE] Got {len(benchmarking_results.test_results)} benchmark results")
finally:
if self.function_to_optimize.is_async:
self.write_code_and_helpers(
1 change: 1 addition & 0 deletions codeflash/verification/test_runner.py
@@ -325,6 +325,7 @@ def run_benchmarking_tests(
pytest_max_loops: int = 100_000,
js_project_root: Path | None = None,
) -> tuple[Path, subprocess.CompletedProcess]:
logger.debug(f"run_benchmarking_tests called: framework={test_framework}, num_files={len(test_paths.test_files)}")
# Check if there's a language support for this test framework that implements run_benchmarking_tests
language_support = get_language_support_by_framework(test_framework)
if language_support is not None and hasattr(language_support, "run_benchmarking_tests"):