diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index 9c15df4569ca..051275ad946c 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -442,15 +442,13 @@ def set_timesteps( sigmas = np.flip(sigmas).copy() sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) - if self.config.beta_schedule != "squaredcos_cap_v2": - timesteps = timesteps.round() + timesteps = timesteps.round() elif self.config.use_lu_lambdas: lambdas = np.flip(log_sigmas.copy()) lambdas = self._convert_to_lu(in_lambdas=lambdas, num_inference_steps=num_inference_steps) sigmas = np.exp(lambdas) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) - if self.config.beta_schedule != "squaredcos_cap_v2": - timesteps = timesteps.round() + timesteps = timesteps.round() elif self.config.use_exponential_sigmas: sigmas = np.flip(sigmas).copy() sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps) @@ -467,6 +465,16 @@ def set_timesteps( else: sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + # When using karras or lu sigmas with certain beta schedules (e.g. squaredcos_cap_v2), + # the sigma-to-timestep mapping can produce duplicate integer timesteps. Deduplicate + # them to prevent the step index from drifting out of bounds during multistep updates. + timesteps_int = np.round(timesteps).astype(np.int64) + _, unique_indices = np.unique(timesteps_int, return_index=True) + if len(unique_indices) < len(timesteps_int): + unique_indices = np.sort(unique_indices) + timesteps = timesteps[unique_indices] + sigmas = sigmas[unique_indices] + if self.config.final_sigmas_type == "sigma_min": sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 elif self.config.final_sigmas_type == "zero":