Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions openevolve/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,7 @@ class Config:
# General settings
max_iterations: int = 10000
checkpoint_interval: int = 100
checkpoint_on_improvement: bool = False
log_level: str = "INFO"
log_dir: Optional[str] = None
random_seed: Optional[int] = 42
Expand Down
25 changes: 19 additions & 6 deletions openevolve/process_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,7 @@ def _serialize_config(self, config: Config) -> dict:
"evaluator": asdict(config.evaluator),
"max_iterations": config.max_iterations,
"checkpoint_interval": config.checkpoint_interval,
"checkpoint_on_improvement": config.checkpoint_on_improvement,
"log_level": config.log_level,
"log_dir": config.log_dir,
"random_seed": config.random_seed,
Expand Down Expand Up @@ -658,21 +659,33 @@ async def run_evolution(
self._warned_about_combined_score = True

# Check for new best
if self.database.best_program_id == child_program.id:
is_new_best = self.database.best_program_id == child_program.id
if is_new_best:
logger.info(
f"🌟 New best solution found at iteration {completed_iteration}: "
f"{child_program.id}"
)

# Checkpoint callback
# Don't checkpoint at iteration 0 (that's just the initial program)
if (
interval_hit = (
completed_iteration > 0
and completed_iteration % self.config.checkpoint_interval == 0
):
logger.info(
f"Checkpoint interval reached at iteration {completed_iteration}"
)
)
improvement_hit = (
completed_iteration > 0
and self.config.checkpoint_on_improvement
and is_new_best
)
if interval_hit or improvement_hit:
if interval_hit:
logger.info(
f"Checkpoint interval reached at iteration {completed_iteration}"
)
else:
logger.info(
f"Checkpointing new best solution at iteration {completed_iteration}"
)
self.database.log_island_status()
if checkpoint_callback:
checkpoint_callback(completed_iteration)
Expand Down
67 changes: 67 additions & 0 deletions tests/test_process_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,73 @@ async def run_test():
# Run the async test
asyncio.run(run_test())

def test_checkpoint_on_improvement_only_fires_for_new_best(self):
"""Test checkpoint callback fires on improvement when interval is not reached"""

async def run_test():
self.config.checkpoint_on_improvement = True
self.config.checkpoint_interval = 10000
controller = ProcessParallelController(self.config, self.eval_file, self.database)
checkpoint_calls = []

with patch.object(controller, "_submit_iteration") as mock_submit:
mock_future1 = MagicMock()
mock_result1 = SerializableResult(
child_program_dict={
"id": "child_best",
"code": "def evolved(): return 1",
"language": "python",
"parent_id": "test_0",
"generation": 1,
"metrics": {"score": 1.0, "performance": 1.0},
"iteration_found": 1,
"metadata": {"changes": "improved", "island": 0},
},
parent_id="test_0",
iteration_time=0.1,
iteration=1,
target_island=0,
)
mock_future1.done.return_value = True
mock_future1.result.return_value = mock_result1
mock_future1.cancel.return_value = True

mock_future2 = MagicMock()
mock_result2 = SerializableResult(
child_program_dict={
"id": "child_not_best",
"code": "def evolved(): return 0",
"language": "python",
"parent_id": "test_0",
"generation": 1,
"metrics": {"score": 0.1, "performance": 0.1},
"iteration_found": 2,
"metadata": {"changes": "not improved", "island": 1},
},
parent_id="test_0",
iteration_time=0.1,
iteration=2,
target_island=1,
)
mock_future2.done.return_value = True
mock_future2.result.return_value = mock_result2
mock_future2.cancel.return_value = True

mock_submit.side_effect = [mock_future1, mock_future2]

controller.start()
await controller.run_evolution(
start_iteration=1,
max_iterations=2,
target_score=None,
checkpoint_callback=checkpoint_calls.append,
)
controller.stop()

self.assertEqual(checkpoint_calls, [1])

asyncio.run(run_test())

def test_request_shutdown(self):
"""Test graceful shutdown request"""
controller = ProcessParallelController(self.config, self.eval_file, self.database)
Expand Down
Loading