diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py index ee94b1ebdb..3b2d6ec447 100644 --- a/monai/inferers/inferer.py +++ b/monai/inferers/inferer.py @@ -11,6 +11,7 @@ from __future__ import annotations +import inspect import math import warnings from abc import ABC, abstractmethod @@ -861,6 +862,42 @@ def __init__(self, scheduler: Scheduler) -> None: # type: ignore[override] self.scheduler = scheduler + @staticmethod + def _scheduler_step_supports_kwarg(scheduler: Scheduler, kwarg: str) -> bool: + try: + return kwarg in inspect.signature(scheduler.step).parameters + except (TypeError, ValueError): + return False + + @staticmethod + def _get_previous_sample_from_step_output(step_output: Any) -> torch.Tensor: + if isinstance(step_output, tuple): + return step_output[0] + if isinstance(step_output, Mapping): + return step_output["prev_sample"] + if hasattr(step_output, "prev_sample"): + return step_output.prev_sample + raise TypeError("Unsupported scheduler.step output. Expected a tuple or an object with `prev_sample`.") + + def _scheduler_step( + self, + scheduler: Scheduler, + model_output: torch.Tensor, + timestep: int | torch.Tensor, + sample: torch.Tensor, + next_timestep: int | torch.Tensor | None = None, + ) -> torch.Tensor: + step_kwargs = {} + if self._scheduler_step_supports_kwarg(scheduler, "return_dict"): + step_kwargs["return_dict"] = False + + if isinstance(scheduler, RFlowScheduler): + step_output = scheduler.step(model_output, timestep, sample, next_timestep, **step_kwargs) # type: ignore + else: + step_output = scheduler.step(model_output, timestep, sample, **step_kwargs) # type: ignore + + return self._get_previous_sample_from_step_output(step_output) + def __call__( # type: ignore[override] self, inputs: torch.Tensor, @@ -940,7 +977,12 @@ def sample( scheduler = self.scheduler image = input_noise - all_next_timesteps = torch.cat((scheduler.timesteps[1:], torch.tensor([0], dtype=scheduler.timesteps.dtype))) + all_next_timesteps = torch.cat( + ( + scheduler.timesteps[1:], + torch.tensor([0], dtype=scheduler.timesteps.dtype, device=scheduler.timesteps.device), + ) + ) if verbose and has_tqdm: progress_bar = tqdm( zip(scheduler.timesteps, all_next_timesteps), @@ -984,10 +1026,9 @@ def sample( model_output = model_output_uncond + cfg * (model_output_cond - model_output_uncond) # 2. compute previous image: x_t -> x_t-1 - if not isinstance(scheduler, RFlowScheduler): - image, _ = scheduler.step(model_output, t, image) # type: ignore - else: - image, _ = scheduler.step(model_output, t, image, next_t) # type: ignore + image = self._scheduler_step( + scheduler=scheduler, model_output=model_output, timestep=t, sample=image, next_timestep=next_t + ) if save_intermediates and t % intermediate_steps == 0: intermediates.append(image) @@ -1046,7 +1087,7 @@ def get_likelihood( total_kl = torch.zeros(inputs.shape[0]).to(inputs.device) for t in progress_bar: timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long() - noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps) + noisy_image = scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps) diffusion_model = ( partial(diffusion_model, seg=seg) if isinstance(diffusion_model, SPADEDiffusionModelUNet) @@ -1509,7 +1550,12 @@ def sample( # type: ignore[override] scheduler = self.scheduler image = input_noise - all_next_timesteps = torch.cat((scheduler.timesteps[1:], torch.tensor([0], dtype=scheduler.timesteps.dtype))) + all_next_timesteps = torch.cat( + ( + scheduler.timesteps[1:], + torch.tensor([0], dtype=scheduler.timesteps.dtype, device=scheduler.timesteps.device), + ) + ) if verbose and has_tqdm: progress_bar = tqdm( zip(scheduler.timesteps, all_next_timesteps), @@ -1583,10 +1629,9 @@ def sample( # type: ignore[override] model_output = model_output_uncond + cfg * (model_output_cond - model_output_uncond) # 3. compute previous image: x_t -> x_t-1 - if not isinstance(scheduler, RFlowScheduler): - image, _ = scheduler.step(model_output, t, image) # type: ignore - else: - image, _ = scheduler.step(model_output, t, image, next_t) # type: ignore + image = self._scheduler_step( + scheduler=scheduler, model_output=model_output, timestep=t, sample=image, next_timestep=next_t + ) if save_intermediates and t % intermediate_steps == 0: intermediates.append(image) @@ -1647,7 +1692,7 @@ def get_likelihood( # type: ignore[override] total_kl = torch.zeros(inputs.shape[0]).to(inputs.device) for t in progress_bar: timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long() - noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps) + noisy_image = scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps) diffuse = diffusion_model if isinstance(diffusion_model, SPADEDiffusionModelUNet): diff --git a/tests/inferers/test_diffusion_inferer.py b/tests/inferers/test_diffusion_inferer.py index 81874ed3a8..5099a482e0 100644 --- a/tests/inferers/test_diffusion_inferer.py +++ b/tests/inferers/test_diffusion_inferer.py @@ -55,6 +55,31 @@ ] +class DiffusersLikeSchedulerOutput: + def __init__(self, prev_sample: torch.Tensor, pred_original_sample: torch.Tensor) -> None: + self.prev_sample = prev_sample + self.pred_original_sample = pred_original_sample + + +class DiffusersStyleDDPMScheduler(DDPMScheduler): + def step( + self, + model_output: torch.Tensor, + timestep: int, + sample: torch.Tensor, + generator: torch.Generator | None = None, + return_dict: bool = True, + ): + prev_sample, pred_original_sample = super().step( + model_output=model_output, timestep=timestep, sample=sample, generator=generator + ) + if return_dict: + return DiffusersLikeSchedulerOutput( + prev_sample=prev_sample, pred_original_sample=pred_original_sample + ) + return prev_sample, pred_original_sample + + class TestDiffusionSamplingInferer(unittest.TestCase): @parameterized.expand(TEST_CASES) @skipUnless(has_einops, "Requires einops") @@ -126,6 +151,23 @@ def test_ddpm_sampler(self, model_params, input_shape): ) self.assertEqual(len(intermediates), 10) + @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") + def test_diffusers_style_ddpm_sampler(self, model_params, input_shape): + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + noise = torch.randn(input_shape).to(device) + scheduler = DiffusersStyleDDPMScheduler(num_train_timesteps=1000) + inferer = DiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + sample, intermediates = inferer.sample( + input_noise=noise, diffusion_model=model, scheduler=scheduler, save_intermediates=True, intermediate_steps=1 + ) + self.assertEqual(sample.shape, noise.shape) + self.assertEqual(len(intermediates), 10) + @parameterized.expand(TEST_CASES) @skipUnless(has_einops, "Requires einops") def test_ddim_sampler(self, model_params, input_shape): diff --git a/tests/inferers/test_latent_diffusion_inferer.py b/tests/inferers/test_latent_diffusion_inferer.py index ab80363cde..e12e2b963c 100644 --- a/tests/inferers/test_latent_diffusion_inferer.py +++ b/tests/inferers/test_latent_diffusion_inferer.py @@ -313,6 +313,33 @@ ], ] +TEST_CASES_DIFFUSERS = [TEST_CASES[0]] + + +class DiffusersLikeSchedulerOutput: + def __init__(self, prev_sample: torch.Tensor, pred_original_sample: torch.Tensor) -> None: + self.prev_sample = prev_sample + self.pred_original_sample = pred_original_sample + + +class DiffusersStyleDDPMScheduler(DDPMScheduler): + def step( + self, + model_output: torch.Tensor, + timestep: int, + sample: torch.Tensor, + generator: torch.Generator | None = None, + return_dict: bool = True, + ): + prev_sample, pred_original_sample = super().step( + model_output=model_output, timestep=timestep, sample=sample, generator=generator + ) + if return_dict: + return DiffusersLikeSchedulerOutput( + prev_sample=prev_sample, pred_original_sample=pred_original_sample + ) + return prev_sample, pred_original_sample + class TestDiffusionSamplingInferer(unittest.TestCase): @parameterized.expand(TEST_CASES) @@ -414,6 +441,37 @@ def test_sample_shape( ) self.assertEqual(sample.shape, input_shape) + @parameterized.expand(TEST_CASES_DIFFUSERS) + @skipUnless(has_einops, "Requires einops") + def test_diffusers_style_ddpm_sample_shape( + self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + else: + stage_1 = VQVAE(**autoencoder_params) + + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + noise = torch.randn(latent_shape).to(device) + scheduler = DiffusersStyleDDPMScheduler(num_train_timesteps=1000) + inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + sample = inferer.sample( + input_noise=noise, autoencoder_model=stage_1, diffusion_model=stage_2, scheduler=scheduler + ) + self.assertEqual(sample.shape, input_shape) + @parameterized.expand(TEST_CASES) @skipUnless(has_einops, "Requires einops") def test_sample_shape_with_cfg(