Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tests/languages/javascript/test_support_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,9 +182,9 @@ def test_passes_loop_parameters(self, mock_vitest_runner: MagicMock, js_support:

call_kwargs = mock_vitest_runner.call_args.kwargs
assert call_kwargs["min_loops"] == 10
# JS/TS uses JS_BENCHMARKING_MAX_LOOPS (5_000) regardless of passed value
# JS/TS uses JS_BENCHMARKING_MAX_LOOPS (1_000) regardless of passed value
# Actual loop count is limited by target_duration, not max_loops
assert call_kwargs["max_loops"] == 5_000
assert call_kwargs["max_loops"] == 1_000
assert call_kwargs["target_duration_ms"] == 5000


Expand Down
8 changes: 5 additions & 3 deletions tests/test_languages/test_javascript_test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -907,15 +907,17 @@ def test_reporter_produces_valid_junit_xml(self):

# Create a Node.js script that exercises the reporter with mock data
test_script = Path(tmpdir) / "test_reporter.js"
reporter_path_js = reporter_path.as_posix()
output_file_js = output_file.as_posix()
test_script.write_text(f"""
// Set env vars BEFORE requiring reporter (matches real Jest behavior)
process.env.JEST_JUNIT_OUTPUT_FILE = '{output_file}';
process.env.JEST_JUNIT_OUTPUT_FILE = '{output_file_js}';
process.env.JEST_JUNIT_CLASSNAME = '{{filepath}}';
process.env.JEST_JUNIT_SUITE_NAME = '{{filepath}}';
process.env.JEST_JUNIT_ADD_FILE_ATTRIBUTE = 'true';
process.env.JEST_JUNIT_INCLUDE_CONSOLE_OUTPUT = 'true';

const Reporter = require('{reporter_path}');
const Reporter = require('{reporter_path_js}');

// Mock Jest globalConfig
const globalConfig = {{ rootDir: '/tmp/project' }};
Expand Down Expand Up @@ -960,7 +962,7 @@ def test_reporter_produces_valid_junit_xml(self):
reporter.onRunComplete([], results);

console.log('OK');
""")
""", encoding="utf-8")

result = subprocess.run(
["node", str(test_script)],
Expand Down
25 changes: 10 additions & 15 deletions tests/test_trace_benchmarks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import shutil
import sqlite3
import time
from pathlib import Path

import pytest
Expand All @@ -18,6 +17,7 @@ def test_trace_benchmarks() -> None:
replay_tests_dir = benchmarks_root / "codeflash_replay_tests"
tests_root = project_root / "tests"
output_file = (benchmarks_root / Path("test_trace_benchmarks.trace")).resolve()
conn: sqlite3.Connection | None = None
trace_benchmarks_pytest(benchmarks_root, tests_root, project_root, output_file)
assert output_file.exists()
try:
Expand Down Expand Up @@ -121,8 +121,8 @@ def test_trace_benchmarks() -> None:
assert actual[4] == expected[4], f"Mismatch at index {idx} for benchmark_function_name"
assert actual[5] == expected[5], f"Mismatch at index {idx} for benchmark_module_path"
assert actual[6] == expected[6], f"Mismatch at index {idx} for benchmark_line_number"
# Close connection
conn.close()
conn = None
generate_replay_test(output_file, replay_tests_dir)
test_class_sort_path = replay_tests_dir / Path(
"test_tests_pytest_benchmarks_test_test_benchmark_bubble_sort_example__replay_test_0.py"
Expand Down Expand Up @@ -217,7 +217,8 @@ def test_code_to_optimize_bubble_sort_codeflash_trace_sorter_test_no_func():
"""
assert test_sort_path.read_text("utf-8").strip() == test_sort_code.strip()
finally:
# cleanup
if conn is not None:
conn.close()
output_file.unlink(missing_ok=True)
shutil.rmtree(replay_tests_dir)

Expand All @@ -231,6 +232,7 @@ def test_trace_multithreaded_benchmark() -> None:
output_file = (benchmarks_root / Path("test_trace_benchmarks.trace")).resolve()
trace_benchmarks_pytest(benchmarks_root, tests_root, project_root, output_file)
assert output_file.exists()
conn: sqlite3.Connection | None = None
try:
# check contents of trace file
# connect to database
Expand All @@ -244,8 +246,6 @@ def test_trace_multithreaded_benchmark() -> None:
)
function_calls = cursor.fetchall()

conn.close()

# Assert the length of function calls
assert len(function_calls) == 10, f"Expected 10 function calls, but got {len(function_calls)}"
function_benchmark_timings = codeflash_benchmark_plugin.get_function_benchmark_timings(output_file)
Expand Down Expand Up @@ -281,11 +281,9 @@ def test_trace_multithreaded_benchmark() -> None:
assert actual[4] == expected[4], f"Mismatch at index {idx} for benchmark_function_name"
assert actual[5] == expected[5], f"Mismatch at index {idx} for benchmark_module_path"
assert actual[6] == expected[6], f"Mismatch at index {idx} for benchmark_line_number"
# Close connection
conn.close()

finally:
# cleanup
if conn is not None:
conn.close()
output_file.unlink(missing_ok=True)


Expand All @@ -296,6 +294,7 @@ def test_trace_benchmark_decorator() -> None:
output_file = (benchmarks_root / Path("test_trace_benchmarks.trace")).resolve()
trace_benchmarks_pytest(benchmarks_root, tests_root, project_root, output_file)
assert output_file.exists()
conn: sqlite3.Connection | None = None
try:
# check contents of trace file
# connect to database
Expand Down Expand Up @@ -352,11 +351,7 @@ def test_trace_benchmark_decorator() -> None:
assert Path(actual[3]).name == Path(expected[3]).name, f"Mismatch at index {idx} for file_path"
assert actual[4] == expected[4], f"Mismatch at index {idx} for benchmark_function_name"
assert actual[5] == expected[5], f"Mismatch at index {idx} for benchmark_module_path"
# Close connection
cursor.close()
conn.close()
time.sleep(2)
finally:
# cleanup
if conn is not None:
conn.close()
output_file.unlink(missing_ok=True)
time.sleep(1)
Loading