From c3c0eecc191da6623ab1b71e3c45ceab9bb0210f Mon Sep 17 00:00:00 2001 From: Matt Van Horn <455140+mvanhorn@users.noreply.github.com> Date: Fri, 8 May 2026 15:59:48 -0700 Subject: [PATCH] feat(config): add checkpoint_on_improvement option to save only on new best Adds a new config flag, checkpoint_on_improvement, that triggers a checkpoint callback whenever a new-best program is found, in addition to the existing checkpoint_interval gate. Default False preserves existing behavior. When the optimization is making slow progress, checkpoint_interval saves work that the next interval would overwrite without any new best. This option lets users say 'only checkpoint when there's actually something new to save.' Wires the flag through the worker config dict (process_parallel.py:386) and adds a unit test that verifies the callback fires for a new-best run and not for a not-best run when checkpoint_interval is set high. Fixes #434 --- openevolve/config.py | 1 + openevolve/process_parallel.py | 25 ++++++++++--- tests/test_process_parallel.py | 67 ++++++++++++++++++++++++++++++++++ 3 files changed, 87 insertions(+), 6 deletions(-) diff --git a/openevolve/config.py b/openevolve/config.py index bef193da21..35056e4545 100644 --- a/openevolve/config.py +++ b/openevolve/config.py @@ -404,6 +404,7 @@ class Config: # General settings max_iterations: int = 10000 checkpoint_interval: int = 100 + checkpoint_on_improvement: bool = False log_level: str = "INFO" log_dir: Optional[str] = None random_seed: Optional[int] = 42 diff --git a/openevolve/process_parallel.py b/openevolve/process_parallel.py index a2fd6592a9..71c081542e 100644 --- a/openevolve/process_parallel.py +++ b/openevolve/process_parallel.py @@ -384,6 +384,7 @@ def _serialize_config(self, config: Config) -> dict: "evaluator": asdict(config.evaluator), "max_iterations": config.max_iterations, "checkpoint_interval": config.checkpoint_interval, + "checkpoint_on_improvement": config.checkpoint_on_improvement, "log_level": config.log_level, "log_dir": config.log_dir, "random_seed": config.random_seed, @@ -658,7 +659,8 @@ async def run_evolution( self._warned_about_combined_score = True # Check for new best - if self.database.best_program_id == child_program.id: + is_new_best = self.database.best_program_id == child_program.id + if is_new_best: logger.info( f"🌟 New best solution found at iteration {completed_iteration}: " f"{child_program.id}" @@ -666,13 +668,24 @@ async def run_evolution( # Checkpoint callback # Don't checkpoint at iteration 0 (that's just the initial program) - if ( + interval_hit = ( completed_iteration > 0 and completed_iteration % self.config.checkpoint_interval == 0 - ): - logger.info( - f"Checkpoint interval reached at iteration {completed_iteration}" - ) + ) + improvement_hit = ( + completed_iteration > 0 + and self.config.checkpoint_on_improvement + and is_new_best + ) + if interval_hit or improvement_hit: + if interval_hit: + logger.info( + f"Checkpoint interval reached at iteration {completed_iteration}" + ) + else: + logger.info( + f"Checkpointing new best solution at iteration {completed_iteration}" + ) self.database.log_island_status() if checkpoint_callback: checkpoint_callback(completed_iteration) diff --git a/tests/test_process_parallel.py b/tests/test_process_parallel.py index 8cdd525b33..44ae778e2a 100644 --- a/tests/test_process_parallel.py +++ b/tests/test_process_parallel.py @@ -153,6 +153,73 @@ async def run_test(): # Run the async test asyncio.run(run_test()) + def test_checkpoint_on_improvement_only_fires_for_new_best(self): + """Test checkpoint callback fires on improvement when interval is not reached""" + + async def run_test(): + self.config.checkpoint_on_improvement = True + self.config.checkpoint_interval = 10000 + controller = ProcessParallelController(self.config, self.eval_file, self.database) + checkpoint_calls = [] + + with patch.object(controller, "_submit_iteration") as mock_submit: + mock_future1 = MagicMock() + mock_result1 = SerializableResult( + child_program_dict={ + "id": "child_best", + "code": "def evolved(): return 1", + "language": "python", + "parent_id": "test_0", + "generation": 1, + "metrics": {"score": 1.0, "performance": 1.0}, + "iteration_found": 1, + "metadata": {"changes": "improved", "island": 0}, + }, + parent_id="test_0", + iteration_time=0.1, + iteration=1, + target_island=0, + ) + mock_future1.done.return_value = True + mock_future1.result.return_value = mock_result1 + mock_future1.cancel.return_value = True + + mock_future2 = MagicMock() + mock_result2 = SerializableResult( + child_program_dict={ + "id": "child_not_best", + "code": "def evolved(): return 0", + "language": "python", + "parent_id": "test_0", + "generation": 1, + "metrics": {"score": 0.1, "performance": 0.1}, + "iteration_found": 2, + "metadata": {"changes": "not improved", "island": 1}, + }, + parent_id="test_0", + iteration_time=0.1, + iteration=2, + target_island=1, + ) + mock_future2.done.return_value = True + mock_future2.result.return_value = mock_result2 + mock_future2.cancel.return_value = True + + mock_submit.side_effect = [mock_future1, mock_future2] + + controller.start() + await controller.run_evolution( + start_iteration=1, + max_iterations=2, + target_score=None, + checkpoint_callback=checkpoint_calls.append, + ) + controller.stop() + + self.assertEqual(checkpoint_calls, [1]) + + asyncio.run(run_test()) + def test_request_shutdown(self): """Test graceful shutdown request""" controller = ProcessParallelController(self.config, self.eval_file, self.database)