diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index d08b6c5a5973..b1f62794eaf8 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -470,7 +470,11 @@ def encode_prompt( # We are only ALWAYS interested in the pooled output of the final text encoder if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2: negative_pooled_prompt_embeds = negative_prompt_embeds[0] - negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + if clip_skip is None: + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-(clip_skip + 2)] negative_prompt_embeds_list.append(negative_prompt_embeds) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index 19ccfab3de0a..6590abdf2c4e 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -488,7 +488,11 @@ def encode_prompt( # We are only ALWAYS interested in the pooled output of the final text encoder if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2: negative_pooled_prompt_embeds = negative_prompt_embeds[0] - negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + if clip_skip is None: + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-(clip_skip + 2)] negative_prompt_embeds_list.append(negative_prompt_embeds) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py index 7382d597102c..b3c19e31f1ab 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -592,7 +592,11 @@ def encode_prompt( # We are only ALWAYS interested in the pooled output of the final text encoder if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2: negative_pooled_prompt_embeds = negative_prompt_embeds[0] - negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + if clip_skip is None: + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-(clip_skip + 2)] negative_prompt_embeds_list.append(negative_prompt_embeds) diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py index c9afdc3209cd..fb61b9154e71 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py @@ -844,6 +844,39 @@ def test_stable_diffusion_xl_multi_prompts(self): # ensure the results are not equal assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4 + def test_encode_prompt_clip_skip_applies_to_negative_embeds(self): + device = "cpu" + components = self.get_dummy_components() + sd_pipe = StableDiffusionXLPipeline(**components).to(device) + + prompt = "A painting of a squirrel eating a burger" + negative_prompt = "blurry, low quality" + + # encode without clip_skip + prompt_embeds_no_skip, negative_embeds_no_skip, _, _ = sd_pipe.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=1, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt, + clip_skip=None, + ) + + # encode with clip_skip=1 + prompt_embeds_skip, negative_embeds_skip, _, _ = sd_pipe.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=1, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt, + clip_skip=1, + ) + + # positive embeds already respected clip_skip (sanity check) + assert not torch.allclose(prompt_embeds_no_skip, prompt_embeds_skip) + # negative embeds must also change with clip_skip (this is the bug fix) + assert not torch.allclose(negative_embeds_no_skip, negative_embeds_skip) + def test_stable_diffusion_xl_negative_conditions(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator components = self.get_dummy_components()