@staticmethod
def _scheduler_step_supports_kwarg(scheduler: Scheduler, kwarg: str) -> bool:
    """
    Report whether ``scheduler.step`` declares a parameter named ``kwarg``.

    Args:
        scheduler: scheduler whose ``step`` callable is inspected.
        kwarg: name of the parameter to look for.

    Returns:
        ``True`` if ``kwarg`` is a named parameter of ``scheduler.step``, ``False`` if it is
        missing or the signature cannot be introspected.
    """
    import inspect  # function-scope import keeps this helper self-contained

    try:
        return kwarg in inspect.signature(scheduler.step).parameters
    except (TypeError, ValueError):
        # inspect.signature raises these for callables it cannot describe
        return False


@staticmethod
def _get_previous_sample_from_step_output(step_output: Any) -> torch.Tensor:
    """
    Extract the previous sample from whatever ``scheduler.step`` returned.

    Args:
        step_output: return value of ``scheduler.step``; a tuple (first element is used),
            a mapping with a ``"prev_sample"`` key, or an object with a ``prev_sample`` attribute.

    Returns:
        the entry of ``step_output`` identified as the previous sample.

    Raises:
        TypeError: if ``step_output`` is none of the supported containers.
    """
    if isinstance(step_output, tuple):
        return step_output[0]
    if isinstance(step_output, Mapping):
        return step_output["prev_sample"]
    if hasattr(step_output, "prev_sample"):
        return step_output.prev_sample
    raise TypeError("Unsupported scheduler.step output. Expected a tuple or an object with `prev_sample`.")


@staticmethod
def _get_scheduler_name(scheduler: Scheduler) -> str:
    """
    Return a name for ``scheduler``: the result of its ``_get_name()`` method when it has one,
    otherwise the name of its class.
    """
    if hasattr(scheduler, "_get_name"):
        return scheduler._get_name()
    return scheduler.__class__.__name__


@staticmethod
def _get_scheduler_config_value(scheduler: Scheduler, name: str, default: Any = None) -> Any:
    """
    Look up a setting on a scheduler, checking its ``config`` before the scheduler itself.

    Args:
        scheduler: scheduler to query.
        name: name of the setting.
        default: value returned when the setting is found nowhere.

    Returns:
        ``scheduler.config[name]`` when ``config`` is a mapping containing ``name``;
        ``scheduler.config.<name>`` when ``config`` is a non-mapping object with that attribute;
        otherwise ``scheduler.<name>`` if present, else ``default``.
    """
    config = getattr(scheduler, "config", None)
    if isinstance(config, Mapping):
        if name in config:
            return config[name]
    elif config is not None and hasattr(config, name):
        return getattr(config, name)

    # no usable config entry: fall back to an attribute on the scheduler itself
    if hasattr(scheduler, name):
        return getattr(scheduler, name)
    return default


@staticmethod
def _get_posterior_mean(
    scheduler: Scheduler, timestep: int | torch.Tensor, x_0: torch.Tensor, x_t: torch.Tensor
) -> torch.Tensor:
    """
    Compute the posterior mean as a weighted sum of ``x_0`` and ``x_t``.

    The weights are built from ``scheduler.alphas``, ``scheduler.alphas_cumprod`` and
    ``scheduler.betas`` at ``timestep``; ``scheduler.one`` replaces the previous cumulative
    alpha when ``timestep`` is not greater than 0.

    Args:
        scheduler: object exposing ``alphas``, ``alphas_cumprod``, ``betas`` and ``one``.
        timestep: index into the scheduler's schedules.
        x_0: weighted by ``sqrt(alpha_prod_t_prev) * beta_t / (1 - alpha_prod_t)``.
        x_t: weighted by ``sqrt(alpha_t) * (1 - alpha_prod_t_prev) / (1 - alpha_prod_t)``.

    Returns:
        the weighted sum of ``x_0`` and ``x_t``.
    """
    alpha_t = scheduler.alphas[timestep]
    alpha_prod_t = scheduler.alphas_cumprod[timestep]
    alpha_prod_t_prev = scheduler.alphas_cumprod[timestep - 1] if timestep > 0 else scheduler.one

    x_0_coefficient = alpha_prod_t_prev.sqrt() * scheduler.betas[timestep] / (1 - alpha_prod_t)
    x_t_coefficient = alpha_t.sqrt() * (1 - alpha_prod_t_prev) / (1 - alpha_prod_t)

    return x_0_coefficient * x_0 + x_t_coefficient * x_t


def _get_posterior_variance(
    self, scheduler: Scheduler, timestep: int | torch.Tensor, predicted_variance: torch.Tensor | None = None
) -> torch.Tensor:
    """
    Compute the posterior variance at ``timestep`` according to the scheduler's ``variance_type``.

    Args:
        scheduler: object exposing ``alphas_cumprod``, ``betas`` and ``one``; ``variance_type`` is
            read through ``_get_scheduler_config_value``.
        timestep: index into the scheduler's schedules.
        predicted_variance: variance predicted by the model, used only for the ``"learned"`` and
            ``"learned_range"`` variance types.

    Returns:
        - ``"fixed_large"``: ``betas[timestep]``.
        - ``"learned"`` with a prediction: ``predicted_variance`` unchanged.
        - ``"learned_range"`` with a prediction: interpolation between the base variance and
          ``betas[timestep]`` with weight ``(predicted_variance + 1) / 2``.
        - anything else (including ``"fixed_small"``): the base variance
          ``(1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * betas[timestep]``.

        The base variance is always clamped to at least ``1e-20``.
    """
    alpha_prod_t = scheduler.alphas_cumprod[timestep]
    alpha_prod_t_prev = scheduler.alphas_cumprod[timestep - 1] if timestep > 0 else scheduler.one
    beta_t = scheduler.betas[timestep]

    # The base variance is exactly 0 at timestep 0 when ``scheduler.one`` is 1, and get_likelihood
    # takes torch.log of the result. Clamp for every variance type, not only "fixed_small", so the
    # fall-through paths never yield log(0) = -inf.
    variance = torch.clamp((1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * beta_t, min=1e-20)
    variance_type = self._get_scheduler_config_value(scheduler, "variance_type")

    if variance_type == "fixed_large":
        return beta_t
    if predicted_variance is not None:
        if variance_type == "learned":
            return predicted_variance
        if variance_type == "learned_range":
            frac = (predicted_variance + 1) / 2
            return frac * beta_t + (1 - frac) * variance

    return variance


def _scheduler_step(
    self,
    scheduler: Scheduler,
    model_output: torch.Tensor,
    timestep: int | torch.Tensor,
    sample: torch.Tensor,
    next_timestep: int | torch.Tensor | None = None,
) -> torch.Tensor:
    """
    Call ``scheduler.step`` and return only the previous sample.

    ``return_dict=False`` is passed when ``scheduler.step`` declares a ``return_dict`` parameter.
    ``next_timestep`` is forwarded only when ``scheduler`` is an ``RFlowScheduler``.

    Args:
        scheduler: scheduler whose ``step`` method is called.
        model_output: first positional argument of ``scheduler.step``.
        timestep: second positional argument of ``scheduler.step``.
        sample: third positional argument of ``scheduler.step``.
        next_timestep: fourth positional argument, used for ``RFlowScheduler`` only.

    Returns:
        the previous sample extracted by ``_get_previous_sample_from_step_output``.
    """
    step_kwargs: dict[str, Any] = {}
    if self._scheduler_step_supports_kwarg(scheduler, "return_dict"):
        step_kwargs["return_dict"] = False

    if isinstance(scheduler, RFlowScheduler):
        step_output = scheduler.step(model_output, timestep, sample, next_timestep, **step_kwargs)  # type: ignore
    else:
        step_output = scheduler.step(model_output, timestep, sample, **step_kwargs)  # type: ignore

    return self._get_previous_sample_from_step_output(step_output)
self.scheduler image = input_noise - all_next_timesteps = torch.cat((scheduler.timesteps[1:], torch.tensor([0], dtype=scheduler.timesteps.dtype))) + all_next_timesteps = torch.cat( + ( + scheduler.timesteps[1:], + torch.tensor([0], dtype=scheduler.timesteps.dtype, device=scheduler.timesteps.device), + ) + ) if verbose and has_tqdm: progress_bar = tqdm( zip(scheduler.timesteps, all_next_timesteps), @@ -984,10 +1080,9 @@ def sample( model_output = model_output_uncond + cfg * (model_output_cond - model_output_uncond) # 2. compute previous image: x_t -> x_t-1 - if not isinstance(scheduler, RFlowScheduler): - image, _ = scheduler.step(model_output, t, image) # type: ignore - else: - image, _ = scheduler.step(model_output, t, image, next_t) # type: ignore + image = self._scheduler_step( + scheduler=scheduler, model_output=model_output, timestep=t, sample=image, next_timestep=next_t + ) if save_intermediates and t % intermediate_steps == 0: intermediates.append(image) @@ -1028,10 +1123,10 @@ def get_likelihood( if not scheduler: scheduler = self.scheduler - if scheduler._get_name() != "DDPMScheduler": + scheduler_name = self._get_scheduler_name(scheduler) + if scheduler_name != "DDPMScheduler": raise NotImplementedError( - f"Likelihood computation is only compatible with DDPMScheduler," - f" you are using {scheduler._get_name()}" + f"Likelihood computation is only compatible with DDPMScheduler," f" you are using {scheduler_name}" ) if mode not in ["crossattn", "concat"]: raise NotImplementedError(f"{mode} condition is not supported") @@ -1046,7 +1141,7 @@ def get_likelihood( total_kl = torch.zeros(inputs.shape[0]).to(inputs.device) for t in progress_bar: timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long() - noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps) + noisy_image = scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps) diffusion_model = ( partial(diffusion_model, seg=seg) 
if isinstance(diffusion_model, SPADEDiffusionModelUNet) @@ -1059,7 +1154,8 @@ def get_likelihood( model_output = diffusion_model(x=noisy_image, timesteps=timesteps, context=conditioning) # get the model's predicted mean, and variance if it is predicted - if model_output.shape[1] == inputs.shape[1] * 2 and scheduler.variance_type in ["learned", "learned_range"]: + variance_type = self._get_scheduler_config_value(scheduler, "variance_type") + if model_output.shape[1] == inputs.shape[1] * 2 and variance_type in ["learned", "learned_range"]: model_output, predicted_variance = torch.split(model_output, inputs.shape[1], dim=1) else: predicted_variance = None @@ -1072,15 +1168,17 @@ def get_likelihood( # 2. compute predicted original sample from predicted noise also called # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf - if scheduler.prediction_type == "epsilon": + prediction_type = self._get_scheduler_config_value(scheduler, "prediction_type") + if prediction_type == "epsilon": pred_original_sample = (noisy_image - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) - elif scheduler.prediction_type == "sample": + elif prediction_type == "sample": pred_original_sample = model_output - elif scheduler.prediction_type == "v_prediction": + elif prediction_type == "v_prediction": pred_original_sample = (alpha_prod_t**0.5) * noisy_image - (beta_prod_t**0.5) * model_output # 3. Clip "predicted x_0" - if scheduler.clip_sample: - pred_original_sample = torch.clamp(pred_original_sample, -1, 1) + if self._get_scheduler_config_value(scheduler, "clip_sample"): + clip_sample_range = self._get_scheduler_config_value(scheduler, "clip_sample_range", 1.0) + pred_original_sample = torch.clamp(pred_original_sample, -clip_sample_range, clip_sample_range) # 4. 
Compute coefficients for pred_original_sample x_0 and current sample x_t # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf @@ -1092,11 +1190,15 @@ def get_likelihood( predicted_mean = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * noisy_image # get the posterior mean and variance - posterior_mean = scheduler._get_mean(timestep=t, x_0=inputs, x_t=noisy_image) # type: ignore[operator] - posterior_variance = scheduler._get_variance(timestep=t, predicted_variance=predicted_variance) # type: ignore[operator] + posterior_mean = self._get_posterior_mean(scheduler=scheduler, timestep=t, x_0=inputs, x_t=noisy_image) + posterior_variance = self._get_posterior_variance( + scheduler=scheduler, timestep=t, predicted_variance=predicted_variance + ) log_posterior_variance = torch.log(posterior_variance) - log_predicted_variance = torch.log(predicted_variance) if predicted_variance else log_posterior_variance + log_predicted_variance = ( + torch.log(predicted_variance) if predicted_variance is not None else log_posterior_variance + ) if t == 0: # compute -log p(x_0|x_1) @@ -1509,7 +1611,12 @@ def sample( # type: ignore[override] scheduler = self.scheduler image = input_noise - all_next_timesteps = torch.cat((scheduler.timesteps[1:], torch.tensor([0], dtype=scheduler.timesteps.dtype))) + all_next_timesteps = torch.cat( + ( + scheduler.timesteps[1:], + torch.tensor([0], dtype=scheduler.timesteps.dtype, device=scheduler.timesteps.device), + ) + ) if verbose and has_tqdm: progress_bar = tqdm( zip(scheduler.timesteps, all_next_timesteps), @@ -1583,10 +1690,9 @@ def sample( # type: ignore[override] model_output = model_output_uncond + cfg * (model_output_cond - model_output_uncond) # 3. 
compute previous image: x_t -> x_t-1 - if not isinstance(scheduler, RFlowScheduler): - image, _ = scheduler.step(model_output, t, image) # type: ignore - else: - image, _ = scheduler.step(model_output, t, image, next_t) # type: ignore + image = self._scheduler_step( + scheduler=scheduler, model_output=model_output, timestep=t, sample=image, next_timestep=next_t + ) if save_intermediates and t % intermediate_steps == 0: intermediates.append(image) @@ -1631,10 +1737,10 @@ def get_likelihood( # type: ignore[override] if not scheduler: scheduler = self.scheduler - if scheduler._get_name() != "DDPMScheduler": + scheduler_name = self._get_scheduler_name(scheduler) + if scheduler_name != "DDPMScheduler": raise NotImplementedError( - f"Likelihood computation is only compatible with DDPMScheduler," - f" you are using {scheduler._get_name()}" + f"Likelihood computation is only compatible with DDPMScheduler," f" you are using {scheduler_name}" ) if mode not in ["crossattn", "concat"]: raise NotImplementedError(f"{mode} condition is not supported") @@ -1647,7 +1753,7 @@ def get_likelihood( # type: ignore[override] total_kl = torch.zeros(inputs.shape[0]).to(inputs.device) for t in progress_bar: timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long() - noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps) + noisy_image = scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps) diffuse = diffusion_model if isinstance(diffusion_model, SPADEDiffusionModelUNet): @@ -1680,7 +1786,8 @@ def get_likelihood( # type: ignore[override] mid_block_additional_residual=mid_block_res_sample, ) # get the model's predicted mean, and variance if it is predicted - if model_output.shape[1] == inputs.shape[1] * 2 and scheduler.variance_type in ["learned", "learned_range"]: + variance_type = self._get_scheduler_config_value(scheduler, "variance_type") + if model_output.shape[1] == inputs.shape[1] * 2 and variance_type in 
["learned", "learned_range"]: model_output, predicted_variance = torch.split(model_output, inputs.shape[1], dim=1) else: predicted_variance = None @@ -1693,15 +1800,17 @@ def get_likelihood( # type: ignore[override] # 2. compute predicted original sample from predicted noise also called # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf - if scheduler.prediction_type == "epsilon": + prediction_type = self._get_scheduler_config_value(scheduler, "prediction_type") + if prediction_type == "epsilon": pred_original_sample = (noisy_image - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) - elif scheduler.prediction_type == "sample": + elif prediction_type == "sample": pred_original_sample = model_output - elif scheduler.prediction_type == "v_prediction": + elif prediction_type == "v_prediction": pred_original_sample = (alpha_prod_t**0.5) * noisy_image - (beta_prod_t**0.5) * model_output # 3. Clip "predicted x_0" - if scheduler.clip_sample: - pred_original_sample = torch.clamp(pred_original_sample, -1, 1) + if self._get_scheduler_config_value(scheduler, "clip_sample"): + clip_sample_range = self._get_scheduler_config_value(scheduler, "clip_sample_range", 1.0) + pred_original_sample = torch.clamp(pred_original_sample, -clip_sample_range, clip_sample_range) # 4. 
Compute coefficients for pred_original_sample x_0 and current sample x_t # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf @@ -1713,11 +1822,15 @@ def get_likelihood( # type: ignore[override] predicted_mean = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * noisy_image # get the posterior mean and variance - posterior_mean = scheduler._get_mean(timestep=t, x_0=inputs, x_t=noisy_image) # type: ignore[operator] - posterior_variance = scheduler._get_variance(timestep=t, predicted_variance=predicted_variance) # type: ignore[operator] + posterior_mean = self._get_posterior_mean(scheduler=scheduler, timestep=t, x_0=inputs, x_t=noisy_image) + posterior_variance = self._get_posterior_variance( + scheduler=scheduler, timestep=t, predicted_variance=predicted_variance + ) log_posterior_variance = torch.log(posterior_variance) - log_predicted_variance = torch.log(predicted_variance) if predicted_variance else log_posterior_variance + log_predicted_variance = ( + torch.log(predicted_variance) if predicted_variance is not None else log_posterior_variance + ) if t == 0: # compute -log p(x_0|x_1) diff --git a/tests/inferers/test_diffusion_inferer.py b/tests/inferers/test_diffusion_inferer.py index 81874ed3a8..9e1a3072dd 100644 --- a/tests/inferers/test_diffusion_inferer.py +++ b/tests/inferers/test_diffusion_inferer.py @@ -24,6 +24,7 @@ _, has_scipy = optional_import("scipy") _, has_einops = optional_import("einops") +DiffusersDDPMScheduler, has_diffusers = optional_import("diffusers", name="DDPMScheduler") TEST_CASES = [ [ @@ -126,6 +127,63 @@ def test_ddpm_sampler(self, model_params, input_shape): ) self.assertEqual(len(intermediates), 10) + @skipUnless(has_einops and has_diffusers, "Requires einops and diffusers") + def test_diffusers_ddpm_call(self): + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model = DiffusionModelUNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=[32, 64], + attention_levels=[False, 
True], + num_res_blocks=1, + num_head_channels=32, + ) + model.to(device) + model.eval() + scheduler = DiffusersDDPMScheduler(num_train_timesteps=1000, beta_schedule="linear", prediction_type="epsilon") + scheduler.set_timesteps(num_inference_steps=50) + inferer = DiffusionInferer(scheduler=scheduler) + + batch_size = 2 + image_size = 32 + inputs = torch.randn(batch_size, 1, image_size, image_size).to(device) + noise = torch.randn_like(inputs) + timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (batch_size,)).long().to(device) + with torch.no_grad(): + prediction = inferer(inputs=inputs, diffusion_model=model, noise=noise, timesteps=timesteps) + + self.assertEqual(prediction.shape, inputs.shape) + scheduler.set_timesteps(num_inference_steps=2) + sample = inferer.sample(input_noise=noise, diffusion_model=model, scheduler=scheduler, verbose=False) + self.assertEqual(sample.shape, inputs.shape) + + @skipUnless(has_einops and has_diffusers, "Requires einops and diffusers") + def test_diffusers_ddpm_get_likelihood(self): + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model = DiffusionModelUNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=[8], + norm_num_groups=8, + attention_levels=[True], + num_res_blocks=1, + num_head_channels=8, + ) + model.to(device) + model.eval() + inputs = torch.randn(2, 1, 8, 8).to(device) + scheduler = DiffusersDDPMScheduler(num_train_timesteps=10, beta_schedule="linear", prediction_type="epsilon") + inferer = DiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + likelihood, intermediates = inferer.get_likelihood( + inputs=inputs, diffusion_model=model, scheduler=scheduler, save_intermediates=True + ) + self.assertEqual(len(intermediates), 10) + self.assertEqual(intermediates[0].shape, inputs.shape) + self.assertEqual(likelihood.shape[0], inputs.shape[0]) + @parameterized.expand(TEST_CASES) @skipUnless(has_einops, "Requires einops") def 
test_ddim_sampler(self, model_params, input_shape): diff --git a/tests/inferers/test_latent_diffusion_inferer.py b/tests/inferers/test_latent_diffusion_inferer.py index ab80363cde..23dd594d8e 100644 --- a/tests/inferers/test_latent_diffusion_inferer.py +++ b/tests/inferers/test_latent_diffusion_inferer.py @@ -23,6 +23,7 @@ from monai.utils import optional_import _, has_einops = optional_import("einops") +DiffusersDDPMScheduler, has_diffusers = optional_import("diffusers", name="DDPMScheduler") TEST_CASES = [ [ "AutoencoderKL", @@ -414,6 +415,46 @@ def test_sample_shape( ) self.assertEqual(sample.shape, input_shape) + @skipUnless(has_einops and has_diffusers, "Requires einops and diffusers") + def test_diffusers_ddpm_sample_shape(self): + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1 = AutoencoderKL( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(4, 4), + latent_channels=3, + attention_levels=[False, False], + num_res_blocks=1, + with_encoder_nonlocal_attn=False, + with_decoder_nonlocal_attn=False, + norm_num_groups=4, + ) + stage_2 = DiffusionModelUNet( + spatial_dims=2, + in_channels=3, + out_channels=3, + channels=[4, 4], + norm_num_groups=4, + attention_levels=[False, False], + num_res_blocks=1, + num_head_channels=4, + ) + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + noise = torch.randn(1, 3, 4, 4).to(device) + scheduler = DiffusersDDPMScheduler(num_train_timesteps=10, beta_schedule="linear", prediction_type="epsilon") + inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + sample = inferer.sample( + input_noise=noise, autoencoder_model=stage_1, diffusion_model=stage_2, scheduler=scheduler + ) + self.assertEqual(sample.shape, (1, 1, 8, 8)) + @parameterized.expand(TEST_CASES) @skipUnless(has_einops, "Requires einops") def test_sample_shape_with_cfg(