diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 1fa4db90d995..f333110a2762 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -2116,7 +2116,14 @@ def from_pipe(cls, pipeline, **kwargs): """ original_config = dict(pipeline.config) - torch_dtype = kwargs.pop("torch_dtype", torch.float32) + torch_dtype = kwargs.pop("torch_dtype", None) + if torch_dtype is None: + dtypes = set() + for component in pipeline.components.values(): + if isinstance(component, torch.nn.Module): + dtypes.add(component.dtype) + if len(dtypes) == 1: + torch_dtype = dtypes.pop() trust_remote_code = kwargs.pop("trust_remote_code", False) # derive the pipeline class to instantiate diff --git a/tests/pipelines/test_pipeline_utils.py b/tests/pipelines/test_pipeline_utils.py index 6d9e68197976..c8dbcf616bd3 100644 --- a/tests/pipelines/test_pipeline_utils.py +++ b/tests/pipelines/test_pipeline_utils.py @@ -12,6 +12,7 @@ AnimateDiffVideoToVideoPipeline, AutoencoderKL, DDIMScheduler, + DiffusionPipeline, MotionAdapter, StableDiffusionImg2ImgPipeline, StableDiffusionInpaintPipeline, @@ -19,6 +20,7 @@ UNet2DConditionModel, ) from diffusers.pipelines.pipeline_loading_utils import is_safetensors_compatible, variant_compatible_siblings +from diffusers.models.modeling_utils import ModelMixin from ..testing_utils import require_torch_accelerator, torch_device @@ -950,3 +952,41 @@ def test_deterministic_dtype(self): pipe_dtype, f"Wrong expected dtype. Expected {self.expected_pipe_dtype}. Got {pipe_dtype}.", ) + + +class FromPipeDtypeTests(unittest.TestCase): + def test_from_pipe_preserves_dtype_by_default(self): + class DummyComponent(ModelMixin): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(4, 4) + + class SourcePipeline(DiffusionPipeline): + def __init__(self, unet: DummyComponent, vae: DummyComponent): + super().__init__() + self.register_modules(unet=unet, vae=vae) + + class TargetPipeline(DiffusionPipeline): + def __init__(self, unet: DummyComponent, vae: DummyComponent): + super().__init__() + self.register_modules(unet=unet, vae=vae) + + unet = DummyComponent().to(dtype=torch.float16) + vae = DummyComponent().to(dtype=torch.float16) + pipe = SourcePipeline(unet=unet, vae=vae) + + new_pipe = TargetPipeline.from_pipe(pipe) + + self.assertEqual(new_pipe.dtype, torch.float16) + for name, component in new_pipe.components.items(): + if isinstance(component, torch.nn.Module): + self.assertEqual( + component.dtype, + torch.float16, + f"Component {name} dtype was not preserved after from_pipe.", + ) + + # components are shared with the original pipeline and must not be cast to float32 + self.assertEqual(pipe.dtype, torch.float16) + self.assertIs(new_pipe.unet, pipe.unet) + self.assertIs(new_pipe.vae, pipe.vae)