diff --git a/openevolve/config.py b/openevolve/config.py index bef193da2..35056e454 100644 --- a/openevolve/config.py +++ b/openevolve/config.py @@ -404,6 +404,7 @@ class Config: # General settings max_iterations: int = 10000 checkpoint_interval: int = 100 + checkpoint_on_improvement: bool = False log_level: str = "INFO" log_dir: Optional[str] = None random_seed: Optional[int] = 42 diff --git a/openevolve/process_parallel.py b/openevolve/process_parallel.py index a2fd6592a..71c081542 100644 --- a/openevolve/process_parallel.py +++ b/openevolve/process_parallel.py @@ -384,6 +384,7 @@ def _serialize_config(self, config: Config) -> dict: "evaluator": asdict(config.evaluator), "max_iterations": config.max_iterations, "checkpoint_interval": config.checkpoint_interval, + "checkpoint_on_improvement": config.checkpoint_on_improvement, "log_level": config.log_level, "log_dir": config.log_dir, "random_seed": config.random_seed, @@ -658,7 +659,8 @@ async def run_evolution( self._warned_about_combined_score = True # Check for new best - if self.database.best_program_id == child_program.id: + is_new_best = self.database.best_program_id == child_program.id + if is_new_best: logger.info( f"🌟 New best solution found at iteration {completed_iteration}: " f"{child_program.id}" @@ -666,13 +668,24 @@ async def run_evolution( # Checkpoint callback # Don't checkpoint at iteration 0 (that's just the initial program) - if ( + interval_hit = ( completed_iteration > 0 and completed_iteration % self.config.checkpoint_interval == 0 - ): - logger.info( - f"Checkpoint interval reached at iteration {completed_iteration}" - ) + ) + improvement_hit = ( + completed_iteration > 0 + and self.config.checkpoint_on_improvement + and is_new_best + ) + if interval_hit or improvement_hit: + if interval_hit: + logger.info( + f"Checkpoint interval reached at iteration {completed_iteration}" + ) + else: + logger.info( + f"Checkpointing new best solution at iteration {completed_iteration}" + ) self.database.log_island_status() if checkpoint_callback: checkpoint_callback(completed_iteration) diff --git a/tests/test_process_parallel.py b/tests/test_process_parallel.py index 8cdd525b3..44ae778e2 100644 --- a/tests/test_process_parallel.py +++ b/tests/test_process_parallel.py @@ -153,6 +153,73 @@ async def run_test(): # Run the async test asyncio.run(run_test()) + def test_checkpoint_on_improvement_only_fires_for_new_best(self): + """Test checkpoint callback fires on improvement when interval is not reached""" + + async def run_test(): + self.config.checkpoint_on_improvement = True + self.config.checkpoint_interval = 10000 + controller = ProcessParallelController(self.config, self.eval_file, self.database) + checkpoint_calls = [] + + with patch.object(controller, "_submit_iteration") as mock_submit: + mock_future1 = MagicMock() + mock_result1 = SerializableResult( + child_program_dict={ + "id": "child_best", + "code": "def evolved(): return 1", + "language": "python", + "parent_id": "test_0", + "generation": 1, + "metrics": {"score": 1.0, "performance": 1.0}, + "iteration_found": 1, + "metadata": {"changes": "improved", "island": 0}, + }, + parent_id="test_0", + iteration_time=0.1, + iteration=1, + target_island=0, + ) + mock_future1.done.return_value = True + mock_future1.result.return_value = mock_result1 + mock_future1.cancel.return_value = True + + mock_future2 = MagicMock() + mock_result2 = SerializableResult( + child_program_dict={ + "id": "child_not_best", + "code": "def evolved(): return 0", + "language": "python", + "parent_id": "test_0", + "generation": 1, + "metrics": {"score": 0.1, "performance": 0.1}, + "iteration_found": 2, + "metadata": {"changes": "not improved", "island": 1}, + }, + parent_id="test_0", + iteration_time=0.1, + iteration=2, + target_island=1, + ) + mock_future2.done.return_value = True + mock_future2.result.return_value = mock_result2 + mock_future2.cancel.return_value = True + + mock_submit.side_effect = [mock_future1, mock_future2] + + controller.start() + await controller.run_evolution( + start_iteration=1, + max_iterations=2, + target_score=None, + checkpoint_callback=checkpoint_calls.append, + ) + controller.stop() + + self.assertEqual(checkpoint_calls, [1]) + + asyncio.run(run_test()) + def test_request_shutdown(self): """Test graceful shutdown request""" controller = ProcessParallelController(self.config, self.eval_file, self.database)