11import shutil
22import sqlite3
3- import time
43from pathlib import Path
54
65import pytest
@@ -18,6 +17,7 @@ def test_trace_benchmarks() -> None:
1817 replay_tests_dir = benchmarks_root / "codeflash_replay_tests"
1918 tests_root = project_root / "tests"
2019 output_file = (benchmarks_root / Path ("test_trace_benchmarks.trace" )).resolve ()
20+ conn : sqlite3 .Connection | None = None
2121 trace_benchmarks_pytest (benchmarks_root , tests_root , project_root , output_file )
2222 assert output_file .exists ()
2323 try :
@@ -121,8 +121,8 @@ def test_trace_benchmarks() -> None:
121121 assert actual [4 ] == expected [4 ], f"Mismatch at index { idx } for benchmark_function_name"
122122 assert actual [5 ] == expected [5 ], f"Mismatch at index { idx } for benchmark_module_path"
123123 assert actual [6 ] == expected [6 ], f"Mismatch at index { idx } for benchmark_line_number"
124- # Close connection
125124 conn .close ()
125+ conn = None
126126 generate_replay_test (output_file , replay_tests_dir )
127127 test_class_sort_path = replay_tests_dir / Path (
128128 "test_tests_pytest_benchmarks_test_test_benchmark_bubble_sort_example__replay_test_0.py"
@@ -217,7 +217,8 @@ def test_code_to_optimize_bubble_sort_codeflash_trace_sorter_test_no_func():
217217"""
218218 assert test_sort_path .read_text ("utf-8" ).strip () == test_sort_code .strip ()
219219 finally :
220- # cleanup
220+ if conn is not None :
221+ conn .close ()
221222 output_file .unlink (missing_ok = True )
222223 shutil .rmtree (replay_tests_dir )
223224
@@ -231,6 +232,7 @@ def test_trace_multithreaded_benchmark() -> None:
231232 output_file = (benchmarks_root / Path ("test_trace_benchmarks.trace" )).resolve ()
232233 trace_benchmarks_pytest (benchmarks_root , tests_root , project_root , output_file )
233234 assert output_file .exists ()
235+ conn : sqlite3 .Connection | None = None
234236 try :
235237 # check contents of trace file
236238 # connect to database
@@ -244,8 +246,6 @@ def test_trace_multithreaded_benchmark() -> None:
244246 )
245247 function_calls = cursor .fetchall ()
246248
247- conn .close ()
248-
249249 # Assert the length of function calls
250250 assert len (function_calls ) == 10 , f"Expected 10 function calls, but got { len (function_calls )} "
251251 function_benchmark_timings = codeflash_benchmark_plugin .get_function_benchmark_timings (output_file )
@@ -281,11 +281,9 @@ def test_trace_multithreaded_benchmark() -> None:
281281 assert actual [4 ] == expected [4 ], f"Mismatch at index { idx } for benchmark_function_name"
282282 assert actual [5 ] == expected [5 ], f"Mismatch at index { idx } for benchmark_module_path"
283283 assert actual [6 ] == expected [6 ], f"Mismatch at index { idx } for benchmark_line_number"
284- # Close connection
285- conn .close ()
286-
287284 finally :
288- # cleanup
285+ if conn is not None :
286+ conn .close ()
289287 output_file .unlink (missing_ok = True )
290288
291289
@@ -296,6 +294,7 @@ def test_trace_benchmark_decorator() -> None:
296294 output_file = (benchmarks_root / Path ("test_trace_benchmarks.trace" )).resolve ()
297295 trace_benchmarks_pytest (benchmarks_root , tests_root , project_root , output_file )
298296 assert output_file .exists ()
297+ conn : sqlite3 .Connection | None = None
299298 try :
300299 # check contents of trace file
301300 # connect to database
@@ -352,11 +351,7 @@ def test_trace_benchmark_decorator() -> None:
352351 assert Path (actual [3 ]).name == Path (expected [3 ]).name , f"Mismatch at index { idx } for file_path"
353352 assert actual [4 ] == expected [4 ], f"Mismatch at index { idx } for benchmark_function_name"
354353 assert actual [5 ] == expected [5 ], f"Mismatch at index { idx } for benchmark_module_path"
355- # Close connection
356- cursor .close ()
357- conn .close ()
358- time .sleep (2 )
359354 finally :
360- # cleanup
355+ if conn is not None :
356+ conn .close ()
361357 output_file .unlink (missing_ok = True )
362- time .sleep (1 )
0 commit comments