Skip to content

Commit a718544

Browse files
add new unittests
1 parent 245bb7d commit a718544

13 files changed

Lines changed: 2577 additions & 0 deletions

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ dependencies = [
2323
[project.optional-dependencies]
2424
dev = [
2525
"pytest>=9.0.0",
26+
"pytest-asyncio>=1.0.0",
2627
"black>=25.1.0",
2728
"isort>=6.0.0",
2829
]

tests/test_ckpt.py

Lines changed: 227 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,227 @@
1+
# ===--------------------------------------------------------------------------------------===#
2+
#
3+
# Part of the CodeEvolve Project, under the Apache License v2.0.
4+
# See https://github.com/inter-co/science-codeevolve/blob/main/LICENSE for license information.
5+
# SPDX-License-Identifier: Apache-2.0
6+
#
7+
# ===--------------------------------------------------------------------------------------===#
8+
#
9+
# This file implements unit tests for checkpointing routines.
10+
#
11+
# ===--------------------------------------------------------------------------------------===#
12+
13+
import json
14+
import logging
15+
from pathlib import Path
16+
from typing import Any, Dict, Optional
17+
18+
import pytest
19+
20+
from codeevolve.database import Program, ProgramDatabase
21+
from codeevolve.islands.sync import GlobalBestProg
22+
from codeevolve.scheduler import ExponentialDecayScheduler, ExplorationRateScheduler
23+
from codeevolve.utils.ckpt import load_ckpt, load_run_metadata, save_ckpt, save_run_metadata
24+
from codeevolve.utils.constants import RUN_METADATA_FILE
25+
26+
27+
# ---------------------------------------------------------------------------
28+
# Helpers
29+
# ---------------------------------------------------------------------------
30+
31+
32+
def _make_db_with_program(island_id: int = 0) -> ProgramDatabase:
33+
"""Creates a ProgramDatabase with one program for testing."""
34+
db: ProgramDatabase = ProgramDatabase(id=island_id, seed=42)
35+
prog: Program = Program(
36+
id="test_prog",
37+
code="def f(): return 1",
38+
language="python",
39+
fitness=10.0,
40+
island_found=island_id,
41+
iteration_found=0,
42+
generation=0,
43+
returncode=0,
44+
eval_metrics={"fitness": 10.0},
45+
)
46+
db.add(prog)
47+
return db
48+
49+
50+
# ---------------------------------------------------------------------------
51+
# save_ckpt / load_ckpt
52+
# ---------------------------------------------------------------------------
53+
54+
55+
class TestCheckpointing:
56+
"""Test suite for checkpoint save and load operations."""
57+
58+
def test_save_and_load_ckpt(self, tmp_path: Path):
59+
"""Tests that checkpoint round-trip preserves database state."""
60+
sol_db: ProgramDatabase = _make_db_with_program()
61+
prompt_db: ProgramDatabase = ProgramDatabase(id=0, seed=42)
62+
prompt_prog: Program = Program(
63+
id="prompt1",
64+
code="You are an expert.",
65+
language="text",
66+
fitness=0.0,
67+
iteration_found=0,
68+
generation=0,
69+
)
70+
prompt_db.add(prompt_prog)
71+
72+
evolve_state: Dict[str, Any] = {
73+
"early_stop_counter": 3,
74+
"best_fit_hist": [1.0, 2.0, 10.0],
75+
"avg_fit_hist": [0.5, 1.0, 5.0],
76+
"errors": [],
77+
"tok_usage": [],
78+
"exploration": [True, False, True],
79+
}
80+
81+
logger: logging.Logger = logging.getLogger("test_ckpt")
82+
83+
best_sol_path: Path = tmp_path / "best_sol.py"
84+
best_prompt_path: Path = tmp_path / "best_prompt.txt"
85+
ckpt_dir: Path = tmp_path / "ckpt"
86+
ckpt_dir.mkdir()
87+
88+
save_ckpt(
89+
curr_epoch=10,
90+
prompt_db=prompt_db,
91+
sol_db=sol_db,
92+
evolve_state=evolve_state,
93+
scheduler=None,
94+
best_sol_path=best_sol_path,
95+
best_prompt_path=best_prompt_path,
96+
ckpt_dir=ckpt_dir,
97+
logger=logger,
98+
)
99+
100+
assert best_sol_path.exists()
101+
assert best_prompt_path.exists()
102+
assert (ckpt_dir / "ckpt_10.pkl").exists()
103+
104+
loaded_prompt_db: Optional[ProgramDatabase]
105+
loaded_sol_db: Optional[ProgramDatabase]
106+
loaded_state: Optional[Dict[str, Any]]
107+
loaded_sched: Optional[ExplorationRateScheduler]
108+
loaded_prompt_db, loaded_sol_db, loaded_state, loaded_sched = load_ckpt(10, ckpt_dir)
109+
110+
assert loaded_sol_db is not None
111+
assert loaded_prompt_db is not None
112+
assert loaded_state is not None
113+
assert loaded_sched is None
114+
assert loaded_sol_db.best_prog_id == "test_prog"
115+
assert loaded_state["early_stop_counter"] == 3
116+
117+
def test_save_and_load_with_scheduler(self, tmp_path: Path):
118+
"""Tests checkpoint round-trip with a scheduler."""
119+
sol_db: ProgramDatabase = _make_db_with_program()
120+
prompt_db: ProgramDatabase = ProgramDatabase(id=0, seed=42)
121+
prompt_db.add(Program(id="pr", code="p", language="text"))
122+
123+
scheduler: ExponentialDecayScheduler = ExponentialDecayScheduler(
124+
exploration_rate=0.5, max_rate=1.0, min_rate=0.01, decay_weight=0.99
125+
)
126+
127+
logger: logging.Logger = logging.getLogger("test_ckpt_sched")
128+
ckpt_dir: Path = tmp_path / "ckpt"
129+
ckpt_dir.mkdir()
130+
131+
save_ckpt(
132+
curr_epoch=5,
133+
prompt_db=prompt_db,
134+
sol_db=sol_db,
135+
evolve_state={"early_stop_counter": 0, "best_fit_hist": [], "avg_fit_hist": [], "errors": [], "tok_usage": [], "exploration": []},
136+
scheduler=scheduler,
137+
best_sol_path=tmp_path / "best.py",
138+
best_prompt_path=tmp_path / "best_prompt.txt",
139+
ckpt_dir=ckpt_dir,
140+
logger=logger,
141+
)
142+
143+
_, _, _, loaded_sched = load_ckpt(5, ckpt_dir)
144+
assert loaded_sched is not None
145+
assert isinstance(loaded_sched, ExponentialDecayScheduler)
146+
assert loaded_sched.decay_weight == 0.99
147+
148+
def test_best_files_content(self, tmp_path: Path):
149+
"""Tests that best solution and prompt files contain correct code."""
150+
sol_db: ProgramDatabase = _make_db_with_program()
151+
prompt_db: ProgramDatabase = ProgramDatabase(id=0, seed=42)
152+
prompt_db.add(Program(id="pr", code="Expert prompt.", language="text"))
153+
154+
logger: logging.Logger = logging.getLogger("test_ckpt_content")
155+
ckpt_dir: Path = tmp_path / "ckpt"
156+
ckpt_dir.mkdir()
157+
158+
best_sol_path: Path = tmp_path / "best_sol.py"
159+
best_prompt_path: Path = tmp_path / "best_prompt.txt"
160+
161+
save_ckpt(
162+
curr_epoch=1,
163+
prompt_db=prompt_db,
164+
sol_db=sol_db,
165+
evolve_state={"early_stop_counter": 0, "best_fit_hist": [], "avg_fit_hist": [], "errors": [], "tok_usage": [], "exploration": []},
166+
scheduler=None,
167+
best_sol_path=best_sol_path,
168+
best_prompt_path=best_prompt_path,
169+
ckpt_dir=ckpt_dir,
170+
logger=logger,
171+
)
172+
173+
sol_content: str = best_sol_path.read_text()
174+
prompt_content: str = best_prompt_path.read_text()
175+
assert "def f(): return 1" in sol_content
176+
assert "Expert prompt." in prompt_content
177+
178+
179+
# ---------------------------------------------------------------------------
180+
# Run metadata
181+
# ---------------------------------------------------------------------------
182+
183+
184+
class TestRunMetadata:
185+
"""Test suite for run metadata save and load operations."""
186+
187+
def test_save_and_load_metadata(self, tmp_path: Path):
188+
"""Tests save/load round-trip for run metadata."""
189+
best_sol: GlobalBestProg = GlobalBestProg()
190+
best_sol.fitness.value = 42.0
191+
best_sol.iteration_found.value = 10
192+
best_sol.island_found.value = 0
193+
best_sol.depth.value = 5
194+
195+
save_run_metadata(tmp_path, epoch=10, elapsed_time=120.5, cpu_count=8, global_best_sol=best_sol)
196+
197+
metadata: Optional[Dict[str, Any]] = load_run_metadata(tmp_path, epoch=10)
198+
assert metadata is not None
199+
assert metadata["elapsed_time"] == 120.5
200+
assert metadata["cpu_count"] == 8
201+
assert metadata["best_sol"]["fitness"] == 42.0
202+
203+
def test_load_nonexistent_metadata(self, tmp_path: Path):
204+
"""Tests loading metadata when file doesn't exist returns None."""
205+
metadata: Optional[Dict[str, Any]] = load_run_metadata(tmp_path, epoch=99)
206+
assert metadata is None
207+
208+
def test_load_missing_epoch(self, tmp_path: Path):
209+
"""Tests loading metadata for a missing epoch returns empty dict."""
210+
best_sol: GlobalBestProg = GlobalBestProg()
211+
save_run_metadata(tmp_path, epoch=10, elapsed_time=100.0, cpu_count=4, global_best_sol=best_sol)
212+
213+
metadata: Optional[Dict[str, Any]] = load_run_metadata(tmp_path, epoch=99)
214+
assert metadata == {}
215+
216+
def test_metadata_accumulates(self, tmp_path: Path):
217+
"""Tests that multiple saves accumulate in the same file."""
218+
best_sol: GlobalBestProg = GlobalBestProg()
219+
save_run_metadata(tmp_path, epoch=10, elapsed_time=100.0, cpu_count=4, global_best_sol=best_sol)
220+
save_run_metadata(tmp_path, epoch=20, elapsed_time=200.0, cpu_count=4, global_best_sol=best_sol)
221+
222+
metadata_file: Path = tmp_path / RUN_METADATA_FILE
223+
with open(metadata_file, "r") as f:
224+
data: Dict[str, Any] = json.load(f)
225+
226+
assert "10" in data
227+
assert "20" in data

0 commit comments

Comments
 (0)