diff --git a/diffsynth/diffusion/logger.py b/diffsynth/diffusion/logger.py index ab6bdb9be..965799682 100644 --- a/diffsynth/diffusion/logger.py +++ b/diffsynth/diffusion/logger.py @@ -17,14 +17,7 @@ def on_step_end(self, accelerator: Accelerator, model: torch.nn.Module, save_ste def on_epoch_end(self, accelerator: Accelerator, model: torch.nn.Module, epoch_id): - accelerator.wait_for_everyone() - state_dict = accelerator.get_state_dict(model) - if accelerator.is_main_process: - state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt) - state_dict = self.state_dict_converter(state_dict) - os.makedirs(self.output_path, exist_ok=True) - path = os.path.join(self.output_path, f"epoch-{epoch_id}.safetensors") - accelerator.save(state_dict, path, safe_serialization=True) + self.save_model(accelerator, model, f"epoch-{epoch_id}.safetensors") def on_training_end(self, accelerator: Accelerator, model: torch.nn.Module, save_steps=None):