diff --git a/src/diffusers/models/downsampling.py b/src/diffusers/models/downsampling.py index 871c0ed7ddf7..4c7a8f8c67bb 100644 --- a/src/diffusers/models/downsampling.py +++ b/src/diffusers/models/downsampling.py @@ -227,7 +227,7 @@ def _downsample_2d( stride_value = [factor, factor] upfirdn_input = upfirdn2d_native( hidden_states, - torch.tensor(kernel, device=hidden_states.device), + kernel.to(device=hidden_states.device, dtype=hidden_states.dtype), pad=((pad_value + 1) // 2, pad_value // 2), ) output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0) @@ -235,7 +235,7 @@ def _downsample_2d( pad_value = kernel.shape[0] - factor output = upfirdn2d_native( hidden_states, - torch.tensor(kernel, device=hidden_states.device), + kernel.to(device=hidden_states.device, dtype=hidden_states.dtype), down=factor, pad=((pad_value + 1) // 2, pad_value // 2), ) @@ -392,7 +392,7 @@ def downsample_2d( pad_value = kernel.shape[0] - factor output = upfirdn2d_native( hidden_states, - kernel.to(device=hidden_states.device), + kernel.to(device=hidden_states.device, dtype=hidden_states.dtype), down=factor, pad=((pad_value + 1) // 2, pad_value // 2), ) diff --git a/src/diffusers/models/upsampling.py b/src/diffusers/models/upsampling.py index cd3986287303..5a185b4d41f0 100644 --- a/src/diffusers/models/upsampling.py +++ b/src/diffusers/models/upsampling.py @@ -300,14 +300,14 @@ def _upsample_2d( output = upfirdn2d_native( inverse_conv, - torch.tensor(kernel, device=inverse_conv.device), + kernel.to(device=inverse_conv.device, dtype=inverse_conv.dtype), pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1), ) else: pad_value = kernel.shape[0] - factor output = upfirdn2d_native( hidden_states, - torch.tensor(kernel, device=hidden_states.device), + kernel.to(device=hidden_states.device, dtype=hidden_states.dtype), up=factor, pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2), ) @@ -508,7 +508,7 @@ def upsample_2d( pad_value = kernel.shape[0] - factor output = 
upfirdn2d_native( hidden_states, - kernel.to(device=hidden_states.device), + kernel.to(device=hidden_states.device, dtype=hidden_states.dtype), up=factor, pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2), ) diff --git a/tests/models/autoencoders/test_models_autoencoder_kl_temporal_decoder.py b/tests/models/autoencoders/test_models_autoencoder_kl_temporal_decoder.py index 7d4ea24d5502..3958fccae936 100644 --- a/tests/models/autoencoders/test_models_autoencoder_kl_temporal_decoder.py +++ b/tests/models/autoencoders/test_models_autoencoder_kl_temporal_decoder.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import pytest import torch from diffusers import AutoencoderKLTemporalDecoder @@ -63,7 +64,16 @@ def get_dummy_inputs(self) -> dict: class TestAutoencoderKLTemporalDecoder(AutoencoderKLTemporalDecoderTesterConfig, ModelTesterMixin): - pass + @pytest.mark.skipif( + torch_device not in ["cuda", "xpu"], + reason="float16 and bfloat16 can only be used for inference with an accelerator", + ) + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"]) + def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype): + # fp16/bf16 convolutions are nondeterministic across the two model instances, so relax the tolerance. 
+ super().test_from_save_pretrained_dtype_inference( + tmp_path, dtype, atol=3e-2 if dtype == torch.bfloat16 else 1e-2 + ) class TestAutoencoderKLTemporalDecoderTraining(AutoencoderKLTemporalDecoderTesterConfig, TrainingTesterMixin): diff --git a/tests/models/autoencoders/test_models_autoencoder_tiny.py b/tests/models/autoencoders/test_models_autoencoder_tiny.py index 7fdab4aeb910..43dda6187505 100644 --- a/tests/models/autoencoders/test_models_autoencoder_tiny.py +++ b/tests/models/autoencoders/test_models_autoencoder_tiny.py @@ -76,7 +76,12 @@ def get_dummy_inputs(self) -> dict: class TestAutoencoderTiny(AutoencoderTinyTesterConfig, ModelTesterMixin): - pass + @pytest.mark.skip( + "`forward` round-trips the latents through a uint8 byte tensor (`.byte()` / `/ 255.0`), which upcasts to " + "float32 regardless of the model dtype, so full fp16/bf16 forward inference is not possible." + ) + def test_from_save_pretrained_dtype_inference(self): + pass class TestAutoencoderTinyTraining(AutoencoderTinyTesterConfig, TrainingTesterMixin): diff --git a/tests/models/controlnets/test_models_controlnet_cosmos.py b/tests/models/controlnets/test_models_controlnet_cosmos.py index 9bef488a8106..e7ea6362213d 100644 --- a/tests/models/controlnets/test_models_controlnet_cosmos.py +++ b/tests/models/controlnets/test_models_controlnet_cosmos.py @@ -283,6 +283,10 @@ def test_training(self): def test_training_with_ema(self): super().test_training_with_ema() + @pytest.mark.skip("ControlNet outputs list of control blocks, not single tensor for MSE loss.") + def test_mixed_precision_training(self): + super().test_mixed_precision_training() + @pytest.mark.skip("ControlNet output doesn't have .sample attribute.") def test_gradient_checkpointing_equivalence(self): super().test_gradient_checkpointing_equivalence() diff --git a/tests/models/testing_utils/common.py b/tests/models/testing_utils/common.py index 5726dba9c600..626f1eb7f1bf 100644 --- a/tests/models/testing_utils/common.py +++ 
b/tests/models/testing_utils/common.py @@ -135,8 +135,9 @@ def cast_inputs_to_dtype(inputs, current_dtype, target_dtype): return inputs.to(target_dtype) if inputs.dtype == current_dtype else inputs if isinstance(inputs, dict): return {k: cast_inputs_to_dtype(v, current_dtype, target_dtype) for k, v in inputs.items()} - if isinstance(inputs, list): - return [cast_inputs_to_dtype(v, current_dtype, target_dtype) for v in inputs] + if isinstance(inputs, (list, tuple)): + # Preserve the container type so models that branch on it (e.g. `isinstance(..., tuple)`) still see a tuple. + return type(inputs)(cast_inputs_to_dtype(v, current_dtype, target_dtype) for v in inputs) return inputs @@ -495,9 +496,12 @@ def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype, atol=1e-4, else: assert param.data.dtype == dtype - inputs = cast_inputs_to_dtype(self.get_dummy_inputs(), torch.float32, dtype) - output = model(**inputs, return_dict=False)[0] - output_loaded = model_loaded(**inputs, return_dict=False)[0] + # Fetch inputs separately for each forward so that models consuming a generator (e.g. stochastic decoders) + # see the same, freshly-seeded RNG state in both passes instead of sharing a single advancing generator. 
+ output = model(**cast_inputs_to_dtype(self.get_dummy_inputs(), torch.float32, dtype), return_dict=False)[0] + output_loaded = model_loaded( + **cast_inputs_to_dtype(self.get_dummy_inputs(), torch.float32, dtype), return_dict=False + )[0] assert_tensors_close( output, output_loaded, atol=atol, rtol=rtol, msg=f"Loaded model output differs for {dtype}" diff --git a/tests/models/transformers/test_models_transformer_chronoedit.py b/tests/models/transformers/test_models_transformer_chronoedit.py index 29fd99b82f7a..8baca5091b98 100644 --- a/tests/models/transformers/test_models_transformer_chronoedit.py +++ b/tests/models/transformers/test_models_transformer_chronoedit.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import pytest import torch from diffusers import ChronoEditTransformer3DModel @@ -92,7 +93,16 @@ def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]: class TestChronoEditTransformer(ChronoEditTransformerTesterConfig, ModelTesterMixin): - pass + @pytest.mark.skipif( + torch_device not in ["cuda", "xpu"], + reason="float16 and bfloat16 can only be used for inference with an accelerator", + ) + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"]) + def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype): + # Modules kept in fp32 diverge from the fully-cast reference, so relax the low-precision tolerance. 
+ super().test_from_save_pretrained_dtype_inference( + tmp_path, dtype, atol=3e-2 if dtype == torch.bfloat16 else 1e-2 + ) class TestChronoEditTransformerTraining(ChronoEditTransformerTesterConfig, TrainingTesterMixin): diff --git a/tests/models/transformers/test_models_transformer_skyreels_v2.py b/tests/models/transformers/test_models_transformer_skyreels_v2.py index 96a43d6f8209..0b895ef799dc 100644 --- a/tests/models/transformers/test_models_transformer_skyreels_v2.py +++ b/tests/models/transformers/test_models_transformer_skyreels_v2.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import pytest import torch from diffusers import SkyReelsV2Transformer3DModel @@ -87,7 +88,16 @@ def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]: class TestSkyReelsV2Transformer(SkyReelsV2TransformerTesterConfig, ModelTesterMixin): - pass + @pytest.mark.skipif( + torch_device not in ["cuda", "xpu"], + reason="float16 and bfloat16 can only be used for inference with an accelerator", + ) + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"]) + def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype): + # Modules kept in fp32 diverge from the fully-cast reference, so relax the low-precision tolerance. 
+ super().test_from_save_pretrained_dtype_inference( + tmp_path, dtype, atol=3e-2 if dtype == torch.bfloat16 else 1e-2 + ) class TestSkyReelsV2TransformerTraining(SkyReelsV2TransformerTesterConfig, TrainingTesterMixin): diff --git a/tests/models/transformers/test_models_transformer_z_image.py b/tests/models/transformers/test_models_transformer_z_image.py index 3a0fe18bc692..ad4a081557c5 100644 --- a/tests/models/transformers/test_models_transformer_z_image.py +++ b/tests/models/transformers/test_models_transformer_z_image.py @@ -131,100 +131,10 @@ def test_determinism(self, atol=1e-5, rtol=0): first[mask], second[mask], atol=atol, rtol=rtol, msg="Model outputs are not deterministic" ) - def test_from_save_pretrained(self, tmp_path, atol=5e-5, rtol=5e-5): - torch.manual_seed(0) - model = self.model_class(**self.get_init_dict()) - model.to(torch_device) - model.eval() - - model.save_pretrained(tmp_path) - new_model = self.model_class.from_pretrained(tmp_path) - new_model.to(torch_device) - - for param_name in model.state_dict().keys(): - param_1 = model.state_dict()[param_name] - param_2 = new_model.state_dict()[param_name] - assert param_1.shape == param_2.shape - - inputs_dict = self.get_dummy_inputs() - image = _concat_list_output(model(**inputs_dict, return_dict=False)[0]) - new_image = _concat_list_output(new_model(**inputs_dict, return_dict=False)[0]) - - assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.") - - @torch.no_grad() - def test_from_save_pretrained_variant(self, tmp_path, atol=5e-5, rtol=0): - model = self.model_class(**self.get_init_dict()) - model.to(torch_device) - model.eval() - - model.save_pretrained(tmp_path, variant="fp16") - new_model = self.model_class.from_pretrained(tmp_path, variant="fp16") - - with pytest.raises(OSError) as exc_info: - self.model_class.from_pretrained(tmp_path) - - assert "Error no file named diffusion_pytorch_model.bin found in directory" in str(exc_info.value) - - 
new_model.to(torch_device) - - inputs_dict = self.get_dummy_inputs() - image = _concat_list_output(model(**inputs_dict, return_dict=False)[0]) - new_image = _concat_list_output(new_model(**inputs_dict, return_dict=False)[0]) - - assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.") - @pytest.mark.skip("Model output `sample` is a list of tensors, not a single tensor.") def test_outputs_equivalence(self, atol=1e-5, rtol=0): pass - def test_sharded_checkpoints_with_parallel_loading(self, tmp_path, atol=1e-5, rtol=0): - from diffusers.utils import SAFE_WEIGHTS_INDEX_NAME, constants - - from ..testing_utils.common import calculate_expected_num_shards, compute_module_persistent_sizes - - torch.manual_seed(0) - config = self.get_init_dict() - inputs_dict = self.get_dummy_inputs() - model = self.model_class(**config).eval() - model = model.to(torch_device) - - base_output = _concat_list_output(model(**inputs_dict, return_dict=False)[0]) - - model_size = compute_module_persistent_sizes(model)[""] - max_shard_size = int((model_size * 0.75) / (2**10)) - - original_parallel_loading = constants.HF_ENABLE_PARALLEL_LOADING - original_parallel_workers = getattr(constants, "HF_PARALLEL_WORKERS", None) - - try: - model.cpu().save_pretrained(tmp_path, max_shard_size=f"{max_shard_size}KB") - assert os.path.exists(os.path.join(tmp_path, SAFE_WEIGHTS_INDEX_NAME)) - - expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_path, SAFE_WEIGHTS_INDEX_NAME)) - actual_num_shards = len([file for file in os.listdir(tmp_path) if file.endswith(".safetensors")]) - assert actual_num_shards == expected_num_shards - - constants.HF_ENABLE_PARALLEL_LOADING = False - self.model_class.from_pretrained(tmp_path).eval().to(torch_device) - - constants.HF_ENABLE_PARALLEL_LOADING = True - constants.DEFAULT_HF_PARALLEL_LOADING_WORKERS = 2 - - torch.manual_seed(0) - model_parallel = self.model_class.from_pretrained(tmp_path).eval() - model_parallel 
= model_parallel.to(torch_device) - - output_parallel = _concat_list_output(model_parallel(**inputs_dict, return_dict=False)[0]) - - assert_tensors_close( - base_output, output_parallel, atol=atol, rtol=rtol, msg="Output should match with parallel loading" - ) - finally: - constants.HF_ENABLE_PARALLEL_LOADING = original_parallel_loading - if original_parallel_workers is not None: - constants.HF_PARALLEL_WORKERS = original_parallel_workers - class TestZImageTransformerMemory(ZImageTransformerTesterConfig, MemoryTesterMixin): """Memory optimization tests for Z-Image Transformer.""" @@ -250,6 +160,10 @@ def test_training(self): def test_training_with_ema(self): pass + @pytest.mark.skip("Model output `sample` is a list of tensors; mixed-precision training computes MSE loss on it.") + def test_mixed_precision_training(self): + pass + @pytest.mark.skip("Test is not supported for handling main inputs that are lists.") def test_gradient_checkpointing_equivalence(self, loss_tolerance=1e-5, param_grad_tol=5e-5, skip=None): pass diff --git a/tests/models/unets/test_models_unet_2d.py b/tests/models/unets/test_models_unet_2d.py index a5cd8abd873a..0399f4301214 100644 --- a/tests/models/unets/test_models_unet_2d.py +++ b/tests/models/unets/test_models_unet_2d.py @@ -74,6 +74,17 @@ def get_dummy_inputs(self) -> dict: class TestUnet2DModel(Unet2DModelTesterConfig, ModelTesterMixin): + @pytest.mark.skipif( + torch_device not in ["cuda", "xpu"], + reason="float16 and bfloat16 can only be used for inference with an accelerator", + ) + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"]) + def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype): + # fp16/bf16 convolutions are nondeterministic across the two model instances, so relax the tolerance. 
+ super().test_from_save_pretrained_dtype_inference( + tmp_path, dtype, atol=3e-2 if dtype == torch.bfloat16 else 1e-2 + ) + def test_mid_block_attn_groups(self): init_dict = self.get_init_dict() init_dict["add_attention"] = True diff --git a/tests/testing_utils.py b/tests/testing_utils.py index a8306b3d65f8..86887d7af6e9 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -165,6 +165,17 @@ def assert_tensors_close( if not is_torch_available(): raise ValueError("PyTorch needs to be installed to use this function.") + # Some models (e.g. Z-Image, Cosmos ControlNet) return a list/tuple of tensors as their output. Compare these + # element-wise so the same helper works regardless of whether the output is a single tensor or a sequence. + if isinstance(actual, (list, tuple)) or isinstance(expected, (list, tuple)): + if not (isinstance(actual, (list, tuple)) and isinstance(expected, (list, tuple))): + raise AssertionError(f"{msg} Type mismatch: actual {type(actual)} vs expected {type(expected)}") + if len(actual) != len(expected): + raise AssertionError(f"{msg} Length mismatch: actual {len(actual)} vs expected {len(expected)}") + for i, (a, e) in enumerate(zip(actual, expected)): + assert_tensors_close(a, e, atol=atol, rtol=rtol, msg=f"{msg} [element {i}]") + return + if actual.shape != expected.shape: raise AssertionError(f"{msg} Shape mismatch: actual {actual.shape} vs expected {expected.shape}")