From 4fe82e94f81420c5d8afa7704a19732f8384b664 Mon Sep 17 00:00:00 2001
From: Abhay Trivedi
Date: Thu, 18 Jun 2026 14:14:06 +0530
Subject: [PATCH] Support batched multi-controlnet conditioning in SDXL ControlNet inpaint pipeline

Allows passing `control_image` as a list of lists so multiple ControlNets
can run over a batch of prompts in
StableDiffusionXLControlNetInpaintPipeline, mirroring the support added for
StableDiffusionControlNetPipeline in #6334.

In check_inputs the nested-list guard is replaced with a check that every
entry is a list holding exactly one image per ControlNet, followed by a
transpose that validates each per-controlnet sublist against the existing
check_image. The length check runs before the transpose because `zip`
silently truncates to the shortest sublist, which would otherwise let a
ragged input through. The same transpose is applied in __call__ before the
prepare_control_image loop. The conditioning scale guard message is
clarified and the now-obsolete prompt-list warning is removed, matching
#6334. check_image is left untouched.

A fast CPU test covers the batched conditioning case and the single-set
broadcast case for two ControlNets.
---
 .../pipeline_controlnet_inpaint_sd_xl.py      |  36 ++--
 .../test_controlnet_inpaint_sdxl.py           | 201 ++++++++++++++++++
 2 files changed, 222 insertions(+), 15 deletions(-)

diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
index f27fcd8aa26f..f817122d6073 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
@@ -760,15 +760,6 @@ def check_inputs(
                 "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
             )
 
-        # `prompt` needs more sophisticated handling when there are multiple
-        # conditionings.
-        if isinstance(self.controlnet, MultiControlNetModel):
-            if isinstance(prompt, list):
-                logger.warning(
-                    f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
-                    " prompts. The conditionings will be fixed across the prompts."
-                )
-
         # Check `image`
         is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
             self.controlnet, torch._dynamo.eval_frame.OptimizedModule
@@ -790,14 +781,20 @@ def check_inputs(
             # When `image` is a nested list:
             # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
             elif any(isinstance(i, list) for i in image):
-                raise ValueError("A single batch of multiple conditionings are supported at the moment.")
+                if not all(isinstance(i, list) and len(i) == len(self.controlnet.nets) for i in image):
+                    raise ValueError(
+                        f"For multiple controlnets: if you pass `image` as a list of lists, every element must be a list whose length equals the number of controlnets, but `image` has sublist lengths {[len(i) if isinstance(i, list) else None for i in image]} and there are {len(self.controlnet.nets)} ControlNets."
+                    )
+                # Lengths are verified above because `zip` truncates to the shortest sublist; transpose to per-controlnet.
+                for image_ in zip(*image):
+                    self.check_image(list(image_), prompt, prompt_embeds)
             elif len(image) != len(self.controlnet.nets):
                 raise ValueError(
                     f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
                 )
-
-            for image_ in image:
-                self.check_image(image_, prompt, prompt_embeds)
+            else:
+                for image_ in image:
+                    self.check_image(image_, prompt, prompt_embeds)
         else:
             assert False
 
@@ -816,7 +813,10 @@ def check_inputs(
         ):
             if isinstance(controlnet_conditioning_scale, list):
                 if any(isinstance(i, list) for i in controlnet_conditioning_scale):
-                    raise ValueError("A single batch of multiple conditionings are supported at the moment.")
+                    raise ValueError(
+                        "A single batch of varying conditioning scale settings (e.g. [[1.0, 0.5], [0.2, 0.8]]) is not supported at the moment. "
+                        "The conditioning scale must be fixed across the batch."
+                    )
             elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
                 self.controlnet.nets
             ):
@@ -1233,7 +1233,8 @@ def __call__(
                 as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or
                 width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
                 images must be passed as a list such that each element of the list can be correctly batched for input
-                to a single ControlNet.
+                to a single ControlNet. You can also pass a list of lists to run multiple ControlNets over a batch of
+                prompts, where each inner list holds one conditioning image per ControlNet for one prompt.
             height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                 The height in pixels of the generated image.
             width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
@@ -1568,6 +1569,11 @@ def denoising_value_valid(dnv):
         elif isinstance(controlnet, MultiControlNetModel):
             control_images = []
 
+            # Nested lists as ControlNet condition
+            if isinstance(control_image[0], list):
+                # Transpose the nested image list
+                control_image = [list(t) for t in zip(*control_image)]
+
             for control_image_ in control_image:
                 control_image_ = self.prepare_control_image(
                     image=control_image_,
diff --git a/tests/pipelines/controlnet/test_controlnet_inpaint_sdxl.py b/tests/pipelines/controlnet/test_controlnet_inpaint_sdxl.py
index c91f2c700c15..89d99e898799 100644
--- a/tests/pipelines/controlnet/test_controlnet_inpaint_sdxl.py
+++ b/tests/pipelines/controlnet/test_controlnet_inpaint_sdxl.py
@@ -36,6 +36,7 @@
     StableDiffusionXLControlNetInpaintPipeline,
     UNet2DConditionModel,
 )
+from diffusers.pipelines.controlnet.pipeline_controlnet import MultiControlNetModel
 from diffusers.utils.import_utils import is_xformers_available
 
 from ...testing_utils import (
@@ -354,3 +355,203 @@ def
test_save_load_optional_components(self):
 
     def test_float16_inference(self):
         super().test_float16_inference(expected_max_diff=5e-1)
+
+
+class StableDiffusionXLMultiControlNetInpaintPipelineFastTests(
+    PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase
+):
+    pipeline_class = StableDiffusionXLControlNetInpaintPipeline
+    params = TEXT_TO_IMAGE_PARAMS
+    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+    image_params = frozenset([])
+    supports_dduf = False
+
+    def get_dummy_components(self):
+        torch.manual_seed(0)
+        unet = UNet2DConditionModel(
+            block_out_channels=(32, 64),
+            layers_per_block=2,
+            sample_size=32,
+            in_channels=4,
+            out_channels=4,
+            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
+            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
+            # SD2-specific config below
+            attention_head_dim=(2, 4),
+            use_linear_projection=True,
+            addition_embed_type="text_time",
+            addition_time_embed_dim=8,
+            transformer_layers_per_block=(1, 2),
+            projection_class_embeddings_input_dim=80,  # 6 * 8 + 32
+            cross_attention_dim=64,
+        )
+        torch.manual_seed(0)
+
+        def init_weights(m):  # give Conv2d layers random weights and a constant bias of 1.0
+            if isinstance(m, torch.nn.Conv2d):
+                torch.nn.init.normal_(m.weight)
+                m.bias.data.fill_(1.0)
+
+        controlnet1 = ControlNetModel(
+            block_out_channels=(32, 64),
+            layers_per_block=2,
+            in_channels=4,
+            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
+            conditioning_embedding_out_channels=(16, 32),
+            # SD2-specific config below
+            attention_head_dim=(2, 4),
+            use_linear_projection=True,
+            addition_embed_type="text_time",
+            addition_time_embed_dim=8,
+            transformer_layers_per_block=(1, 2),
+            projection_class_embeddings_input_dim=80,  # 6 * 8 + 32
+            cross_attention_dim=64,
+        )
+        controlnet1.controlnet_down_blocks.apply(init_weights)
+
+        torch.manual_seed(0)
+        controlnet2 = ControlNetModel(
+            block_out_channels=(32, 64),
+            layers_per_block=2,
+            in_channels=4,
+            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
+            conditioning_embedding_out_channels=(16, 32),
+            # SD2-specific config below
+            attention_head_dim=(2, 4),
+            use_linear_projection=True,
+            addition_embed_type="text_time",
+            addition_time_embed_dim=8,
+            transformer_layers_per_block=(1, 2),
+            projection_class_embeddings_input_dim=80,  # 6 * 8 + 32
+            cross_attention_dim=64,
+        )
+        controlnet2.controlnet_down_blocks.apply(init_weights)
+
+        scheduler = EulerDiscreteScheduler(
+            beta_start=0.00085,
+            beta_end=0.012,
+            steps_offset=1,
+            beta_schedule="scaled_linear",
+            timestep_spacing="leading",
+        )
+        torch.manual_seed(0)
+        vae = AutoencoderKL(
+            block_out_channels=[32, 64],
+            in_channels=3,
+            out_channels=3,
+            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
+            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
+            latent_channels=4,
+        )
+        torch.manual_seed(0)
+        text_encoder_config = CLIPTextConfig(
+            bos_token_id=0,
+            eos_token_id=2,
+            hidden_size=32,
+            intermediate_size=37,
+            layer_norm_eps=1e-05,
+            num_attention_heads=4,
+            num_hidden_layers=5,
+            pad_token_id=1,
+            vocab_size=1000,
+            # SD2-specific config below
+            hidden_act="gelu",
+            projection_dim=32,
+        )
+        text_encoder = CLIPTextModel(text_encoder_config)
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        torch.manual_seed(0)
+        text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config)
+        tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        controlnet = MultiControlNetModel([controlnet1, controlnet2])  # exercise the MultiControlNetModel code path
+
+        components = {
+            "unet": unet,
+            "controlnet": controlnet,
+            "scheduler": scheduler,
+            "vae": vae,
+            "text_encoder": text_encoder,
+            "tokenizer": tokenizer,
+            "text_encoder_2": text_encoder_2,
+            "tokenizer_2": tokenizer_2,
+            "image_encoder": None,
+            "feature_extractor": None,
+        }
+        return components
+
+    def get_dummy_inputs(self, device, seed=0, img_res=64):
+        if str(device).startswith("mps"):
+            generator = torch.manual_seed(seed)
+        else:
+            generator = torch.Generator(device=device).manual_seed(seed)
+
+        # Get random floats in [0, 1] as image
+        image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
+        image = image.cpu().permute(0, 2, 3, 1)[0]
+        mask_image = torch.ones_like(image)
+
+        controlnet_embedder_scale_factor = 2
+        # One conditioning image per ControlNet, each drawn with its own RNG seed (seed + i)
+        control_images = []
+        for i in range(2):
+            control_image = (
+                floats_tensor(
+                    (1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor),
+                    rng=random.Random(seed + i),
+                )
+                .to(device)
+                .cpu()
+                .permute(0, 2, 3, 1)[0]
+            )
+            control_image = 255 * control_image
+            control_images.append(Image.fromarray(np.uint8(control_image)).convert("RGB").resize((img_res, img_res)))
+
+        # Convert image and mask_image to [0, 255]
+        image = 255 * image
+        mask_image = 255 * mask_image
+        # Convert to PIL image
+        init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((img_res, img_res))
+        mask_image = Image.fromarray(np.uint8(mask_image)).convert("L").resize((img_res, img_res))
+
+        inputs = {
+            "prompt": "A painting of a squirrel eating a burger",
+            "generator": generator,
+            "num_inference_steps": 2,
+            "guidance_scale": 6.0,
+            "output_type": "np",
+            "image": init_image,
+            "mask_image": mask_image,
+            "control_image": control_images,
+        }
+        return inputs
+
+    def test_inference_multiple_prompt_input(self):
+        device = "cpu"  # inputs and generator are built on CPU; the pipeline itself is moved to torch_device
+        components = self.get_dummy_components()
+        sd_pipe = StableDiffusionXLControlNetInpaintPipeline(**components)
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        # batched conditioning: control_image is nested as [prompt][controlnet] (two prompts x two ControlNets)
+        inputs = self.get_dummy_inputs(device)
+        inputs["prompt"] = [inputs["prompt"], inputs["prompt"]]
+        inputs["control_image"] = [inputs["control_image"], inputs["control_image"]]
+        output = sd_pipe(**inputs)
+        image = output.images
+        assert image.shape == (2, 64, 64, 3)
+
+        image_1, image_2 = image
+        # the two batch items use different initial noise, so the outputs should differ
+        assert np.sum(np.abs(image_1 - image_2)) > 1e-3
+
+        # a flat list (one image per ControlNet) is broadcast across the prompts and must match the nested-list run
+        inputs = self.get_dummy_inputs(device)
+        inputs["prompt"] = [inputs["prompt"], inputs["prompt"]]
+        output_1 = sd_pipe(**inputs)
+        assert np.abs(image - output_1.images).max() < 1e-3
+
+    # TODO(Patrick, Sayak) - skip for now as this requires more refiner tests
+    def test_save_load_optional_components(self):
+        pass