Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -760,15 +760,6 @@ def check_inputs(
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
)

# `prompt` needs more sophisticated handling when there are multiple
# conditionings.
if isinstance(self.controlnet, MultiControlNetModel):
if isinstance(prompt, list):
logger.warning(
f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
" prompts. The conditionings will be fixed across the prompts."
)

# Check `image`
is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
self.controlnet, torch._dynamo.eval_frame.OptimizedModule
Expand All @@ -790,14 +781,20 @@ def check_inputs(
# When `image` is a nested list:
# (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
elif any(isinstance(i, list) for i in image):
raise ValueError("A single batch of multiple conditionings are supported at the moment.")
transposed_image = [list(t) for t in zip(*image)]
if len(transposed_image) != len(self.controlnet.nets):
raise ValueError(
f"For multiple controlnets: if you pass`image` as a list of list, each sublist must have the same length as the number of controlnets, but the sublists in `image` got {len(transposed_image)} images and {len(self.controlnet.nets)} ControlNets."
)
for image_ in transposed_image:
self.check_image(image_, prompt, prompt_embeds)
elif len(image) != len(self.controlnet.nets):
raise ValueError(
f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
)

for image_ in image:
self.check_image(image_, prompt, prompt_embeds)
else:
for image_ in image:
self.check_image(image_, prompt, prompt_embeds)
else:
assert False

Expand All @@ -816,7 +813,10 @@ def check_inputs(
):
if isinstance(controlnet_conditioning_scale, list):
if any(isinstance(i, list) for i in controlnet_conditioning_scale):
raise ValueError("A single batch of multiple conditionings are supported at the moment.")
raise ValueError(
"A single batch of varying conditioning scale settings (e.g. [[1.0, 0.5], [0.2, 0.8]]) is not supported at the moment. "
"The conditioning scale must be fixed across the batch."
)
elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
self.controlnet.nets
):
Expand Down Expand Up @@ -1233,7 +1233,8 @@ def __call__(
as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or
width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
images must be passed as a list such that each element of the list can be correctly batched for input
to a single ControlNet.
to a single ControlNet. You can also pass a list of lists to run multiple ControlNets over a batch of
prompts, where each inner list holds one conditioning image per ControlNet for one prompt.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
Expand Down Expand Up @@ -1568,6 +1569,11 @@ def denoising_value_valid(dnv):
elif isinstance(controlnet, MultiControlNetModel):
control_images = []

# Nested lists as ControlNet condition
if isinstance(control_image[0], list):
# Transpose the nested image list
control_image = [list(t) for t in zip(*control_image)]

for control_image_ in control_image:
control_image_ = self.prepare_control_image(
image=control_image_,
Expand Down
201 changes: 201 additions & 0 deletions tests/pipelines/controlnet/test_controlnet_inpaint_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
StableDiffusionXLControlNetInpaintPipeline,
UNet2DConditionModel,
)
from diffusers.pipelines.controlnet.pipeline_controlnet import MultiControlNetModel
from diffusers.utils.import_utils import is_xformers_available

from ...testing_utils import (
Expand Down Expand Up @@ -354,3 +355,203 @@ def test_save_load_optional_components(self):

def test_float16_inference(self):
super().test_float16_inference(expected_max_diff=5e-1)


class StableDiffusionXLMultiControlNetInpaintPipelineFastTests(
PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase
):
pipeline_class = StableDiffusionXLControlNetInpaintPipeline
params = TEXT_TO_IMAGE_PARAMS
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = frozenset([])
supports_dduf = False

def get_dummy_components(self):
torch.manual_seed(0)
unet = UNet2DConditionModel(
block_out_channels=(32, 64),
layers_per_block=2,
sample_size=32,
in_channels=4,
out_channels=4,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
# SD2-specific config below
attention_head_dim=(2, 4),
use_linear_projection=True,
addition_embed_type="text_time",
addition_time_embed_dim=8,
transformer_layers_per_block=(1, 2),
projection_class_embeddings_input_dim=80, # 6 * 8 + 32
cross_attention_dim=64,
)
torch.manual_seed(0)

def init_weights(m):
if isinstance(m, torch.nn.Conv2d):
torch.nn.init.normal_(m.weight)
m.bias.data.fill_(1.0)

controlnet1 = ControlNetModel(
block_out_channels=(32, 64),
layers_per_block=2,
in_channels=4,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
conditioning_embedding_out_channels=(16, 32),
# SD2-specific config below
attention_head_dim=(2, 4),
use_linear_projection=True,
addition_embed_type="text_time",
addition_time_embed_dim=8,
transformer_layers_per_block=(1, 2),
projection_class_embeddings_input_dim=80, # 6 * 8 + 32
cross_attention_dim=64,
)
controlnet1.controlnet_down_blocks.apply(init_weights)

torch.manual_seed(0)
controlnet2 = ControlNetModel(
block_out_channels=(32, 64),
layers_per_block=2,
in_channels=4,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
conditioning_embedding_out_channels=(16, 32),
# SD2-specific config below
attention_head_dim=(2, 4),
use_linear_projection=True,
addition_embed_type="text_time",
addition_time_embed_dim=8,
transformer_layers_per_block=(1, 2),
projection_class_embeddings_input_dim=80, # 6 * 8 + 32
cross_attention_dim=64,
)
controlnet2.controlnet_down_blocks.apply(init_weights)

scheduler = EulerDiscreteScheduler(
beta_start=0.00085,
beta_end=0.012,
steps_offset=1,
beta_schedule="scaled_linear",
timestep_spacing="leading",
)
torch.manual_seed(0)
vae = AutoencoderKL(
block_out_channels=[32, 64],
in_channels=3,
out_channels=3,
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
latent_channels=4,
)
torch.manual_seed(0)
text_encoder_config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=32,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
# SD2-specific config below
hidden_act="gelu",
projection_dim=32,
)
text_encoder = CLIPTextModel(text_encoder_config)
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

torch.manual_seed(0)
text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config)
tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

controlnet = MultiControlNetModel([controlnet1, controlnet2])

components = {
"unet": unet,
"controlnet": controlnet,
"scheduler": scheduler,
"vae": vae,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"text_encoder_2": text_encoder_2,
"tokenizer_2": tokenizer_2,
"image_encoder": None,
"feature_extractor": None,
}
return components

def get_dummy_inputs(self, device, seed=0, img_res=64):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)

# Get random floats in [0, 1] as image
image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
image = image.cpu().permute(0, 2, 3, 1)[0]
mask_image = torch.ones_like(image)

controlnet_embedder_scale_factor = 2
# One conditioning image per ControlNet
control_images = []
for i in range(2):
control_image = (
floats_tensor(
(1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor),
rng=random.Random(seed + i),
)
.to(device)
.cpu()
.permute(0, 2, 3, 1)[0]
)
control_image = 255 * control_image
control_images.append(Image.fromarray(np.uint8(control_image)).convert("RGB").resize((img_res, img_res)))

# Convert image and mask_image to [0, 255]
image = 255 * image
mask_image = 255 * mask_image
# Convert to PIL image
init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((img_res, img_res))
mask_image = Image.fromarray(np.uint8(mask_image)).convert("L").resize((img_res, img_res))

inputs = {
"prompt": "A painting of a squirrel eating a burger",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 6.0,
"output_type": "np",
"image": init_image,
"mask_image": mask_image,
"control_image": control_images,
}
return inputs

def test_inference_multiple_prompt_input(self):
device = "cpu"
components = self.get_dummy_components()
sd_pipe = StableDiffusionXLControlNetInpaintPipeline(**components)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)

# batched conditioning: two prompts, each carrying one conditioning image per ControlNet
inputs = self.get_dummy_inputs(device)
inputs["prompt"] = [inputs["prompt"], inputs["prompt"]]
inputs["control_image"] = [inputs["control_image"], inputs["control_image"]]
output = sd_pipe(**inputs)
image = output.images
assert image.shape == (2, 64, 64, 3)

image_1, image_2 = image
# the two batch items use different initial noise, so the outputs should differ
assert np.sum(np.abs(image_1 - image_2)) > 1e-3

# a single set of conditioning images is broadcast across multiple prompts
inputs = self.get_dummy_inputs(device)
inputs["prompt"] = [inputs["prompt"], inputs["prompt"]]
output_1 = sd_pipe(**inputs)
assert np.abs(image - output_1.images).max() < 1e-3

# TODO(Patrick, Sayak) - skip for now as this requires more refiner tests
def test_save_load_optional_components(self):
pass
Loading