Skip to content
6 changes: 3 additions & 3 deletions src/diffusers/models/downsampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,15 +227,15 @@ def _downsample_2d(
stride_value = [factor, factor]
upfirdn_input = upfirdn2d_native(
hidden_states,
torch.tensor(kernel, device=hidden_states.device),
kernel.to(device=hidden_states.device, dtype=hidden_states.dtype),
pad=((pad_value + 1) // 2, pad_value // 2),
)
output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0)
else:
pad_value = kernel.shape[0] - factor
output = upfirdn2d_native(
hidden_states,
torch.tensor(kernel, device=hidden_states.device),
kernel.to(device=hidden_states.device, dtype=hidden_states.dtype),
down=factor,
pad=((pad_value + 1) // 2, pad_value // 2),
)
Expand Down Expand Up @@ -392,7 +392,7 @@ def downsample_2d(
pad_value = kernel.shape[0] - factor
output = upfirdn2d_native(
hidden_states,
kernel.to(device=hidden_states.device),
kernel.to(device=hidden_states.device, dtype=hidden_states.dtype),
down=factor,
pad=((pad_value + 1) // 2, pad_value // 2),
)
Expand Down
6 changes: 3 additions & 3 deletions src/diffusers/models/upsampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,14 +300,14 @@ def _upsample_2d(

output = upfirdn2d_native(
inverse_conv,
torch.tensor(kernel, device=inverse_conv.device),
kernel.to(device=inverse_conv.device, dtype=inverse_conv.dtype),
pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1),
)
else:
pad_value = kernel.shape[0] - factor
output = upfirdn2d_native(
hidden_states,
torch.tensor(kernel, device=hidden_states.device),
kernel.to(device=hidden_states.device, dtype=hidden_states.dtype),
up=factor,
pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
)
Expand Down Expand Up @@ -508,7 +508,7 @@ def upsample_2d(
pad_value = kernel.shape[0] - factor
output = upfirdn2d_native(
hidden_states,
kernel.to(device=hidden_states.device),
kernel.to(device=hidden_states.device, dtype=hidden_states.dtype),
up=factor,
pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import torch

from diffusers import AutoencoderKLTemporalDecoder
Expand Down Expand Up @@ -63,7 +64,16 @@ def get_dummy_inputs(self) -> dict:


class TestAutoencoderKLTemporalDecoder(AutoencoderKLTemporalDecoderTesterConfig, ModelTesterMixin):
    """Run the `ModelTesterMixin` suite against the temporal-decoder VAE tester config.

    Only the half-precision save/load inference test is overridden, to loosen its absolute tolerance.
    """

    @pytest.mark.skipif(
        torch_device not in ["cuda", "xpu"],
        reason="float16 and bfloat16 can only be use for inference with an accelerator",
    )
    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
    def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
        """Delegate to the mixin's test with a looser `atol` for fp16 and bf16."""
        # Stated rationale: fp16/bf16 convolutions are nondeterministic across the two model instances,
        # so the comparison needs a wider tolerance; bf16 gets the widest one.
        if dtype == torch.bfloat16:
            relaxed_atol = 3e-2
        else:
            relaxed_atol = 1e-2
        super().test_from_save_pretrained_dtype_inference(tmp_path, dtype, atol=relaxed_atol)


class TestAutoencoderKLTemporalDecoderTraining(AutoencoderKLTemporalDecoderTesterConfig, TrainingTesterMixin):
Expand Down
7 changes: 6 additions & 1 deletion tests/models/autoencoders/test_models_autoencoder_tiny.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,12 @@ def get_dummy_inputs(self) -> dict:


# Runs the `ModelTesterMixin` suite against the AutoencoderTiny tester config, with the
# half-precision save/load inference test disabled (see the skip reason below).
class TestAutoencoderTiny(AutoencoderTinyTesterConfig, ModelTesterMixin):
    # NOTE(review): this `pass` looks redundant now that the class body defines a method.
    pass
    # The skip is unconditional, so the override below never executes.
    @pytest.mark.skip(
        "`forward` round-trips the latents through a uint8 byte tensor (`.byte()` / `/ 255.0`), which upcasts to "
        "float32 regardless of the model dtype, so full fp16/bf16 forward inference is not possible."
    )
    def test_from_save_pretrained_dtype_inference(self):
        # Intentionally empty: this override exists only to carry the skip marker.
        # NOTE(review): it declares no `tmp_path` / `dtype` parameters; presumably fine because the test
        # is never run - confirm it is collected without a parametrization error.
        pass


class TestAutoencoderTinyTraining(AutoencoderTinyTesterConfig, TrainingTesterMixin):
Expand Down
4 changes: 4 additions & 0 deletions tests/models/controlnets/test_models_controlnet_cosmos.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,10 @@ def test_training(self):
def test_training_with_ema(self):
super().test_training_with_ema()

# Skipped unconditionally: per the skip reason, the ControlNet returns a list of control blocks
# rather than a single tensor, which does not fit the MSE loss this test relies on.
@pytest.mark.skip("ControlNet outputs list of control blocks, not single tensor for MSE loss.")
def test_mixed_precision_training(self):
    # Never executed because of the skip marker; kept as a plain delegation to the parent test.
    super().test_mixed_precision_training()

@pytest.mark.skip("ControlNet output doesn't have .sample attribute.")
def test_gradient_checkpointing_equivalence(self):
super().test_gradient_checkpointing_equivalence()
Expand Down
14 changes: 9 additions & 5 deletions tests/models/testing_utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,9 @@ def cast_inputs_to_dtype(inputs, current_dtype, target_dtype):
return inputs.to(target_dtype) if inputs.dtype == current_dtype else inputs
if isinstance(inputs, dict):
return {k: cast_inputs_to_dtype(v, current_dtype, target_dtype) for k, v in inputs.items()}
if isinstance(inputs, list):
return [cast_inputs_to_dtype(v, current_dtype, target_dtype) for v in inputs]
if isinstance(inputs, (list, tuple)):
# Preserve the container type so models that branch on it (e.g. `isinstance(..., tuple)`) still see a tuple.
return type(inputs)(cast_inputs_to_dtype(v, current_dtype, target_dtype) for v in inputs)

return inputs

Expand Down Expand Up @@ -495,9 +496,12 @@ def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype, atol=1e-4,
else:
assert param.data.dtype == dtype

inputs = cast_inputs_to_dtype(self.get_dummy_inputs(), torch.float32, dtype)
output = model(**inputs, return_dict=False)[0]
output_loaded = model_loaded(**inputs, return_dict=False)[0]
# Fetch inputs separately for each forward so that models consuming a generator (e.g. stochastic decoders)
# see the same, freshly-seeded RNG state in both passes instead of sharing a single advancing generator.
output = model(**cast_inputs_to_dtype(self.get_dummy_inputs(), torch.float32, dtype), return_dict=False)[0]
output_loaded = model_loaded(
**cast_inputs_to_dtype(self.get_dummy_inputs(), torch.float32, dtype), return_dict=False
)[0]

assert_tensors_close(
output, output_loaded, atol=atol, rtol=rtol, msg=f"Loaded model output differs for {dtype}"
Expand Down
12 changes: 11 additions & 1 deletion tests/models/transformers/test_models_transformer_chronoedit.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import torch

from diffusers import ChronoEditTransformer3DModel
Expand Down Expand Up @@ -92,7 +93,16 @@ def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:


class TestChronoEditTransformer(ChronoEditTransformerTesterConfig, ModelTesterMixin):
    """Run the `ModelTesterMixin` suite against the ChronoEdit transformer tester config.

    Only the half-precision save/load inference test is overridden, to loosen its absolute tolerance.
    """

    @pytest.mark.skipif(
        torch_device not in ["cuda", "xpu"],
        reason="float16 and bfloat16 can only be use for inference with an accelerator",
    )
    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
    def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
        """Delegate to the mixin's test with a looser `atol` for fp16 and bf16."""
        # Stated rationale: modules kept in fp32 diverge from the fully-cast reference, so the
        # low-precision comparison needs a wider tolerance; bf16 gets the widest one.
        if dtype == torch.bfloat16:
            relaxed_atol = 3e-2
        else:
            relaxed_atol = 1e-2
        super().test_from_save_pretrained_dtype_inference(tmp_path, dtype, atol=relaxed_atol)


class TestChronoEditTransformerTraining(ChronoEditTransformerTesterConfig, TrainingTesterMixin):
Expand Down
12 changes: 11 additions & 1 deletion tests/models/transformers/test_models_transformer_skyreels_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import torch

from diffusers import SkyReelsV2Transformer3DModel
Expand Down Expand Up @@ -87,7 +88,16 @@ def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:


class TestSkyReelsV2Transformer(SkyReelsV2TransformerTesterConfig, ModelTesterMixin):
    """Run the `ModelTesterMixin` suite against the SkyReels-V2 transformer tester config.

    Only the half-precision save/load inference test is overridden, to loosen its absolute tolerance.
    """

    @pytest.mark.skipif(
        torch_device not in ["cuda", "xpu"],
        reason="float16 and bfloat16 can only be use for inference with an accelerator",
    )
    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
    def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
        """Delegate to the mixin's test with a looser `atol` for fp16 and bf16."""
        # Stated rationale: modules kept in fp32 diverge from the fully-cast reference, so the
        # low-precision comparison needs a wider tolerance; bf16 gets the widest one.
        if dtype == torch.bfloat16:
            relaxed_atol = 3e-2
        else:
            relaxed_atol = 1e-2
        super().test_from_save_pretrained_dtype_inference(tmp_path, dtype, atol=relaxed_atol)


class TestSkyReelsV2TransformerTraining(SkyReelsV2TransformerTesterConfig, TrainingTesterMixin):
Expand Down
94 changes: 4 additions & 90 deletions tests/models/transformers/test_models_transformer_z_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,100 +131,10 @@ def test_determinism(self, atol=1e-5, rtol=0):
first[mask], second[mask], atol=atol, rtol=rtol, msg="Model outputs are not deterministic"
)

def test_from_save_pretrained(self, tmp_path, atol=5e-5, rtol=5e-5):
    """Round-trip the model through `save_pretrained` / `from_pretrained` and compare forward passes.

    Every entry of the original state dict must be present in the reloaded model with the same
    shape, and the concatenated list outputs of both models must agree within `atol` / `rtol`.
    """
    torch.manual_seed(0)
    saved_model = self.model_class(**self.get_init_dict())
    saved_model.to(torch_device)
    saved_model.eval()

    saved_model.save_pretrained(tmp_path)
    loaded_model = self.model_class.from_pretrained(tmp_path)
    loaded_model.to(torch_device)

    # Shape check only; values are compared indirectly through the forward pass below.
    saved_state = saved_model.state_dict()
    loaded_state = loaded_model.state_dict()
    for key, saved_tensor in saved_state.items():
        assert saved_tensor.shape == loaded_state[key].shape

    # Both models receive the very same inputs object.
    dummy_inputs = self.get_dummy_inputs()
    saved_output = _concat_list_output(saved_model(**dummy_inputs, return_dict=False)[0])
    loaded_output = _concat_list_output(loaded_model(**dummy_inputs, return_dict=False)[0])

    assert_tensors_close(
        saved_output, loaded_output, atol=atol, rtol=rtol, msg="Models give different forward passes."
    )

@torch.no_grad()
def test_from_save_pretrained_variant(self, tmp_path, atol=5e-5, rtol=0):
    """Save and reload the model under the "fp16" variant and compare forward passes.

    Also checks that loading the same directory without a variant raises `OSError` with the
    expected "no file named" message.
    """
    source_model = self.model_class(**self.get_init_dict())
    source_model.to(torch_device)
    source_model.eval()

    variant_name = "fp16"
    source_model.save_pretrained(tmp_path, variant=variant_name)
    variant_model = self.model_class.from_pretrained(tmp_path, variant=variant_name)

    # Loading without the variant is expected to fail with the message asserted below.
    with pytest.raises(OSError) as exc_info:
        self.model_class.from_pretrained(tmp_path)

    expected_fragment = "Error no file named diffusion_pytorch_model.bin found in directory"
    assert expected_fragment in str(exc_info.value)

    variant_model.to(torch_device)

    # Both models receive the very same inputs object.
    dummy_inputs = self.get_dummy_inputs()
    source_output = _concat_list_output(source_model(**dummy_inputs, return_dict=False)[0])
    variant_output = _concat_list_output(variant_model(**dummy_inputs, return_dict=False)[0])

    assert_tensors_close(
        source_output, variant_output, atol=atol, rtol=rtol, msg="Models give different forward passes."
    )

# Skipped unconditionally: per the skip reason, the model's `sample` output is a list of tensors
# rather than a single tensor.
@pytest.mark.skip("Model output `sample` is a list of tensors, not a single tensor.")
def test_outputs_equivalence(self, atol=1e-5, rtol=0):
    # Intentionally empty: this stub exists only to carry the skip marker.
    pass

def test_sharded_checkpoints_with_parallel_loading(self, tmp_path, atol=1e-5, rtol=0):
    """Check that a sharded checkpoint loaded with parallel loading reproduces the original output.

    The model is saved with a shard size below its total size, the shard count is verified against
    the index file, and the checkpoint is loaded once with parallel loading disabled and once with
    it enabled. The output of the parallel-loaded model must match the original within
    `atol` / `rtol`.

    The two `diffusers.utils.constants` attributes modified here are restored afterwards, whether
    or not the test passes.
    """
    from diffusers.utils import SAFE_WEIGHTS_INDEX_NAME, constants

    from ..testing_utils.common import calculate_expected_num_shards, compute_module_persistent_sizes

    torch.manual_seed(0)
    config = self.get_init_dict()
    inputs_dict = self.get_dummy_inputs()
    model = self.model_class(**config).eval()
    model = model.to(torch_device)

    base_output = _concat_list_output(model(**inputs_dict, return_dict=False)[0])

    # Cap each shard at 75% of the model size so the save has to produce more than one shard.
    # NOTE(review): assumes the reported size is in bytes, hence the division before the "KB" suffix.
    model_size = compute_module_persistent_sizes(model)[""]
    max_shard_size = int((model_size * 0.75) / (2**10))

    # Snapshot exactly the attributes this test overwrites. A sentinel distinguishes
    # "attribute was absent" from any real value so the restore can be exact.
    missing = object()
    original_parallel_loading = constants.HF_ENABLE_PARALLEL_LOADING
    original_parallel_workers = getattr(constants, "DEFAULT_HF_PARALLEL_LOADING_WORKERS", missing)

    try:
        model.cpu().save_pretrained(tmp_path, max_shard_size=f"{max_shard_size}KB")
        index_path = os.path.join(tmp_path, SAFE_WEIGHTS_INDEX_NAME)
        assert os.path.exists(index_path)

        expected_num_shards = calculate_expected_num_shards(index_path)
        actual_num_shards = len([file for file in os.listdir(tmp_path) if file.endswith(".safetensors")])
        assert actual_num_shards == expected_num_shards

        # Sequential load: only checks that loading succeeds, the result is not compared.
        constants.HF_ENABLE_PARALLEL_LOADING = False
        self.model_class.from_pretrained(tmp_path).eval().to(torch_device)

        constants.HF_ENABLE_PARALLEL_LOADING = True
        constants.DEFAULT_HF_PARALLEL_LOADING_WORKERS = 2

        torch.manual_seed(0)
        model_parallel = self.model_class.from_pretrained(tmp_path).eval()
        model_parallel = model_parallel.to(torch_device)

        output_parallel = _concat_list_output(model_parallel(**inputs_dict, return_dict=False)[0])

        assert_tensors_close(
            base_output, output_parallel, atol=atol, rtol=rtol, msg="Output should match with parallel loading"
        )
    finally:
        constants.HF_ENABLE_PARALLEL_LOADING = original_parallel_loading
        # Restore the same attribute that was overwritten above, so the worker count
        # does not leak into later tests.
        if original_parallel_workers is missing:
            if hasattr(constants, "DEFAULT_HF_PARALLEL_LOADING_WORKERS"):
                delattr(constants, "DEFAULT_HF_PARALLEL_LOADING_WORKERS")
        else:
            constants.DEFAULT_HF_PARALLEL_LOADING_WORKERS = original_parallel_workers


class TestZImageTransformerMemory(ZImageTransformerTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for Z-Image Transformer."""
Expand All @@ -250,6 +160,10 @@ def test_training(self):
def test_training_with_ema(self):
pass

# Skipped unconditionally: per the skip reason, the model's `sample` output is a list of tensors,
# which does not fit the MSE loss that mixed-precision training computes on it.
@pytest.mark.skip("Model output `sample` is a list of tensors; mixed-precision training computes MSE loss on it.")
def test_mixed_precision_training(self):
    # Intentionally empty: this stub exists only to carry the skip marker.
    pass

@pytest.mark.skip("Test is not supported for handling main inputs that are lists.")
def test_gradient_checkpointing_equivalence(self, loss_tolerance=1e-5, param_grad_tol=5e-5, skip=None):
pass
Expand Down
11 changes: 11 additions & 0 deletions tests/models/unets/test_models_unet_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,17 @@ def get_dummy_inputs(self) -> dict:


class TestUnet2DModel(Unet2DModelTesterConfig, ModelTesterMixin):
@pytest.mark.skipif(
    torch_device not in ["cuda", "xpu"],
    reason="float16 and bfloat16 can only be use for inference with an accelerator",
)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
    """Delegate to the parent test with a looser `atol` for fp16 and bf16."""
    # Stated rationale: fp16/bf16 convolutions are nondeterministic across the two model instances,
    # so the comparison needs a wider tolerance; bf16 gets the widest one.
    if dtype == torch.bfloat16:
        relaxed_atol = 3e-2
    else:
        relaxed_atol = 1e-2
    super().test_from_save_pretrained_dtype_inference(tmp_path, dtype, atol=relaxed_atol)

def test_mid_block_attn_groups(self):
init_dict = self.get_init_dict()
init_dict["add_attention"] = True
Expand Down
11 changes: 11 additions & 0 deletions tests/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,17 @@ def assert_tensors_close(
if not is_torch_available():
raise ValueError("PyTorch needs to be installed to use this function.")

# Some models (e.g. Z-Image, Cosmos ControlNet) return a list/tuple of tensors as their output. Compare these
# element-wise so the same helper works regardless of whether the output is a single tensor or a sequence.
if isinstance(actual, (list, tuple)) or isinstance(expected, (list, tuple)):
Comment thread
sayakpaul marked this conversation as resolved.
if not (isinstance(actual, (list, tuple)) and isinstance(expected, (list, tuple))):
raise AssertionError(f"{msg} Type mismatch: actual {type(actual)} vs expected {type(expected)}")
if len(actual) != len(expected):
raise AssertionError(f"{msg} Length mismatch: actual {len(actual)} vs expected {len(expected)}")
for i, (a, e) in enumerate(zip(actual, expected)):
assert_tensors_close(a, e, atol=atol, rtol=rtol, msg=f"{msg} [element {i}]")
return

if actual.shape != expected.shape:
raise AssertionError(f"{msg} Shape mismatch: actual {actual.shape} vs expected {expected.shape}")

Expand Down
Loading