From b90b8249fa32f7d5cffb24cad26e65dfc8c57b5a Mon Sep 17 00:00:00 2001 From: Edward Xiuyuan Yu Date: Mon, 18 May 2026 09:22:23 +0800 Subject: [PATCH] refactor: reuse save_model method remove the duplicated implementaion of model saving in ModelLogger.on_epoch_end with already extract save_model --- diffsynth/diffusion/logger.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/diffsynth/diffusion/logger.py b/diffsynth/diffusion/logger.py index ab6bdb9be..965799682 100644 --- a/diffsynth/diffusion/logger.py +++ b/diffsynth/diffusion/logger.py @@ -17,14 +17,7 @@ def on_step_end(self, accelerator: Accelerator, model: torch.nn.Module, save_ste def on_epoch_end(self, accelerator: Accelerator, model: torch.nn.Module, epoch_id): - accelerator.wait_for_everyone() - state_dict = accelerator.get_state_dict(model) - if accelerator.is_main_process: - state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt) - state_dict = self.state_dict_converter(state_dict) - os.makedirs(self.output_path, exist_ok=True) - path = os.path.join(self.output_path, f"epoch-{epoch_id}.safetensors") - accelerator.save(state_dict, path, safe_serialization=True) + self.save_model(accelerator, model, f"epoch-{epoch_id}.safetensors") def on_training_end(self, accelerator: Accelerator, model: torch.nn.Module, save_steps=None):