diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 6703c9299e80..9347e878c3ba 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -643,6 +643,8 @@ title: Z-Image title: Image - sections: + - local: api/pipelines/diffusion_gemma + title: DiffusionGemma - local: api/pipelines/llada2 title: LLaDA2 title: Text diff --git a/docs/source/en/api/pipelines/diffusion_gemma.md b/docs/source/en/api/pipelines/diffusion_gemma.md new file mode 100644 index 000000000000..e2c6b0c4178a --- /dev/null +++ b/docs/source/en/api/pipelines/diffusion_gemma.md @@ -0,0 +1,143 @@ + + +# DiffusionGemma + +DiffusionGemma is a block-diffusion encoder-decoder language model. A causal encoder reads the clean prompt (and any +previously generated blocks) into a KV cache, and a bidirectional decoder denoises a fixed-size "canvas" of +`canvas_length` tokens by cross-attending to that cache. Generation alternates an outer autoregressive loop over +canvases with an inner denoising loop, where each step samples candidate tokens, commits the most confident ones via +[`BlockRefinementScheduler`] in uniform corruption mode, and renoises the rest. The model itself lives in +`transformers` as `DiffusionGemmaForBlockDiffusion`. 
+ +## Usage + +```py +import torch +from transformers import AutoProcessor, DiffusionGemmaForBlockDiffusion + +from diffusers import BlockRefinementScheduler, DiffusionGemmaPipeline + +model_id = "google/diffusiongemma-26B-A4B-it" +model = DiffusionGemmaForBlockDiffusion.from_pretrained(model_id, dtype=torch.bfloat16, device_map="auto") +processor = AutoProcessor.from_pretrained(model_id) +scheduler = BlockRefinementScheduler() + +pipe = DiffusionGemmaPipeline(model=model, scheduler=scheduler, processor=processor) +output = pipe( + prompt="Why is the sky blue?", + gen_length=256, + num_inference_steps=48, + temperature=0.0, +) +print(output.texts[0]) +``` + +`num_inference_steps` is the number of denoising steps per canvas (48 matches the released checkpoint); fewer steps are +faster but give lower quality. For multimodal prompts, pass an `image` alongside the `prompt` (or put the image content in a +raw `messages` conversation), and the processor turns it into the model's image inputs automatically. + +## Schedulers + +The scheduler is the sampler that denoises each canvas, and it is interchangeable: swap it to change the sampling +strategy without touching anything else. Three schedulers are available: + +- `BlockRefinementScheduler` (default): commits the most confident tokens each step (above `threshold`, plus an even + per-step quota) and renoises the rest. `editing_threshold` additionally lets it re-edit already committed tokens. +- `DiscreteDDIMScheduler`: samples each position from the exact discrete posterior of the uniform corruption process + (D3PM). The predictor itself has no sampling knobs (the optional corrector sweeps are described below), and the final step deterministically commits the predicted tokens. +- `EntropyBoundScheduler`: commits the lowest-entropy positions whose joint entropy stays under `entropy_bound`, so + roughly independent tokens are accepted together. 
+ +```py +from diffusers import DiscreteDDIMScheduler, EntropyBoundScheduler + +pipe.scheduler = DiscreteDDIMScheduler() +# or: pipe.scheduler = EntropyBoundScheduler(entropy_bound=0.1) +output = pipe(prompt="Why is the sky blue?", gen_length=256, num_inference_steps=48) +print(output.texts[0]) +``` + +Scheduler-specific sampling knobs (the block-refinement `threshold`/`top_k`, the entropy bound, ...) are set on the +scheduler config: + +```py +from diffusers import BlockRefinementScheduler + +pipe.scheduler = BlockRefinementScheduler.from_config(pipe.scheduler.config, threshold=0.9) +``` + +### Predictor-corrector sampling + +`DiscreteDDIMScheduler` supports the leave-one-out predictor-corrector of [Reparameterizing Uniform Diffusion Models](https://huggingface.co/papers/2605.22765). After each predictor step the pipeline runs `corrector_steps` Gibbs sweeps that resample the least-confident positions from the one-coordinate conditional of the noisy marginal, which leaves that marginal invariant and improves generation at no extra training cost. It works directly on the released checkpoint: for uniform diffusion the denoiser and the leave-one-out posterior are interchangeable in closed form, so the corrector recovers the leave-one-out quantities it needs without any retraining. + +```py +from diffusers import DiscreteDDIMScheduler + +pipe.scheduler = DiscreteDDIMScheduler(corrector_steps=2, corrector_k=12) +output = pipe(prompt="Why is the sky blue?", gen_length=256, num_inference_steps=48) +print(output.texts[0]) +``` + +## PEFT adapters + +The denoiser is a 🤗 Transformers model, so adapters are loaded through its native [PEFT](https://huggingface.co/docs/peft) integration rather than the diffusers `load_lora_weights` API. Because that integration is adapter-type-agnostic, the same calls load LoRA, DoRA, or any other PEFT adapter (e.g. the output of TRL's `SFTTrainer`). 
The pipeline forwards the PEFT API so you can manage adapters from the pipeline directly: + +```py +pipe.load_adapter("path/to/adapter", adapter_name="sft") # LoRA, DoRA, ... +pipe.set_adapter("sft") +output = pipe(prompt="Why is the sky blue?", gen_length=256) + +pipe.disable_adapters() # run the base model +pipe.delete_adapter("sft") +``` + +Adapters stay active and unmerged: DiffusionGemma ties the encoder and decoder base weights, so fusing an adapter into them would corrupt both branches. + +## Static cache and compilation + +The pipeline prefills the encoder once per block into a reusable cache (a `DynamicCache` by default). Pass +`cache_implementation="static"` to use a fixed-shape `StaticCache` instead, whose shapes let you `torch.compile` the +decoder for a further speedup: + +```py +pipe.model.model.decoder = torch.compile(pipe.model.model.decoder, fullgraph=True) +output = pipe(prompt="Why is the sky blue?", gen_length=256, cache_implementation="static") +``` + +## Callbacks + +Callbacks run after each denoising step. Pass `callback_on_step_end_tensor_inputs` to select which tensors are +included in `callback_kwargs`; `canvas` (the current block tokens) and `logits` are available. Return `{"canvas": ...}` +from the callback to replace the canvas. + +```py +def on_step_end(pipe, step, timestep, callback_kwargs): + canvas = callback_kwargs["canvas"] + # Inspect or modify `canvas` here. 
+ return {"canvas": canvas} + + +out = pipe( + prompt="Why is the sky blue?", + callback_on_step_end=on_step_end, + callback_on_step_end_tensor_inputs=["canvas"], +) +``` + +## DiffusionGemmaPipeline +[[autodoc]] DiffusionGemmaPipeline + - all + - __call__ + +## DiffusionGemmaPipelineOutput +[[autodoc]] pipelines.DiffusionGemmaPipelineOutput diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 6353347503e1..27978c1f3d8b 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -385,6 +385,10 @@ "AmusedScheduler", "BlockRefinementScheduler", "BlockRefinementSchedulerOutput", + "DiscreteDDIMScheduler", + "DiscreteDDIMSchedulerOutput", + "EntropyBoundScheduler", + "EntropyBoundSchedulerOutput", "CMStochasticIterativeScheduler", "CogVideoXDDIMScheduler", "CogVideoXDPMScheduler", @@ -572,6 +576,8 @@ "CosmosTextToWorldPipeline", "CosmosVideoToWorldPipeline", "CycleDiffusionPipeline", + "DiffusionGemmaPipeline", + "DiffusionGemmaPipelineOutput", "DreamLiteMobilePipeline", "DreamLitePipeline", "DreamLitePipelineOutput", @@ -1251,11 +1257,15 @@ DDPMScheduler, DDPMWuerstchenScheduler, DEISMultistepScheduler, + DiscreteDDIMScheduler, + DiscreteDDIMSchedulerOutput, DPMSolverMultistepInverseScheduler, DPMSolverMultistepScheduler, DPMSolverSinglestepScheduler, EDMDPMSolverMultistepScheduler, EDMEulerScheduler, + EntropyBoundScheduler, + EntropyBoundSchedulerOutput, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, FlowMapEulerDiscreteScheduler, @@ -1407,6 +1417,8 @@ CosmosTextToWorldPipeline, CosmosVideoToWorldPipeline, CycleDiffusionPipeline, + DiffusionGemmaPipeline, + DiffusionGemmaPipelineOutput, DreamLiteMobilePipeline, DreamLitePipeline, DreamLitePipelineOutput, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 850a991941ff..e2e3d2ba2f5f 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -273,6 +273,7 @@ "IFPipeline", "IFSuperResolutionPipeline", 
] + _import_structure["diffusion_gemma"] = ["DiffusionGemmaPipeline", "DiffusionGemmaPipelineOutput"] _import_structure["dreamlite"] = ["DreamLitePipeline", "DreamLiteMobilePipeline", "DreamLitePipelineOutput"] _import_structure["easyanimate"] = [ "EasyAnimatePipeline", @@ -716,6 +717,7 @@ WuerstchenDecoderPipeline, WuerstchenPriorPipeline, ) + from .diffusion_gemma import DiffusionGemmaPipeline, DiffusionGemmaPipelineOutput from .dreamlite import ( DreamLiteMobilePipeline, DreamLitePipeline, diff --git a/src/diffusers/pipelines/diffusion_gemma/__init__.py b/src/diffusers/pipelines/diffusion_gemma/__init__.py new file mode 100644 index 000000000000..282efe37a797 --- /dev/null +++ b/src/diffusers/pipelines/diffusion_gemma/__init__.py @@ -0,0 +1,47 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_diffusion_gemma"] = ["DiffusionGemmaPipeline", "DiffusionGemmaPipelineOutput"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .pipeline_diffusion_gemma import DiffusionGemmaPipeline, DiffusionGemmaPipelineOutput +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, 
value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/diffusion_gemma/pipeline_diffusion_gemma.py b/src/diffusers/pipelines/diffusion_gemma/pipeline_diffusion_gemma.py new file mode 100644 index 000000000000..f7dfd0627d5c --- /dev/null +++ b/src/diffusers/pipelines/diffusion_gemma/pipeline_diffusion_gemma.py @@ -0,0 +1,451 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import inspect +from dataclasses import dataclass +from typing import Any, Callable + +import torch +import torch.nn.functional as F +from tqdm.auto import tqdm +from transformers import DynamicCache, StaticCache + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...schedulers import BlockRefinementScheduler, DiscreteDDIMScheduler, EntropyBoundScheduler +from ...utils import BaseOutput, logging, replace_example_docstring +from ..pipeline_utils import DiffusionPipeline + + +logger = logging.get_logger(__name__) + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from transformers import AutoProcessor, DiffusionGemmaForBlockDiffusion + >>> from diffusers import BlockRefinementScheduler, DiffusionGemmaPipeline + + >>> model_id = "google/diffusiongemma-26B-A4B-it" + >>> model = DiffusionGemmaForBlockDiffusion.from_pretrained(model_id, dtype=torch.bfloat16, device_map="auto") + >>> processor = AutoProcessor.from_pretrained(model_id) + >>> scheduler = BlockRefinementScheduler() + + >>> pipe = DiffusionGemmaPipeline(model=model, scheduler=scheduler, processor=processor) + >>> output = pipe(prompt="Why is the sky blue?", gen_length=256) + >>> print(output.texts[0]) + ``` +""" + + +@dataclass +class DiffusionGemmaPipelineOutput(BaseOutput): + sequences: torch.LongTensor + texts: list[str] | None = None + + +class DiffusionGemmaPipeline(DiffusionPipeline): + r""" + Pipeline for DiffusionGemma block-diffusion text generation. + + DiffusionGemma is a block-diffusion encoder-decoder model: a causal encoder reads the clean prompt (and any + previously generated blocks) into a KV cache, and a bidirectional decoder denoises a fixed-size "canvas" of + `canvas_length` tokens by cross-attending to that cache. 
Generation alternates an outer autoregressive loop over + canvases with an inner denoising loop, where each step samples candidate tokens, commits the most confident ones + via [`BlockRefinementScheduler`] (uniform corruption mode, `mask_token_id=None`), and renoises the rest. + + The model is expected to be a `DiffusionGemmaForBlockDiffusion` instance exposing `forward(input_ids, + decoder_input_ids=..., self_conditioning_logits=..., ...)` and returning logits of shape `[batch, canvas_length, + vocab_size]` over the canvas. + """ + + model: Any + scheduler: BlockRefinementScheduler | DiscreteDDIMScheduler | EntropyBoundScheduler + processor: Any + + _callback_tensor_inputs = ["canvas", "logits"] + + def __init__( + self, + model: Any, + scheduler: BlockRefinementScheduler | DiscreteDDIMScheduler | EntropyBoundScheduler, + processor: Any | None = None, + ): + super().__init__() + self.register_modules(model=model, scheduler=scheduler, processor=processor) + text_config = model.config.get_text_config() + self.canvas_length = model.config.canvas_length + self.vocab_size = text_config.vocab_size + tokenizer = getattr(processor, "tokenizer", processor) + self.eos_token_id = getattr(tokenizer, "eos_token_id", None) if tokenizer is not None else None + + @property + def num_timesteps(self): + return self._num_timesteps + + # --- PEFT adapters --- + # + # The denoiser is a 🤗 Transformers model (a `PeftAdapterMixin`), not a diffusers `ModelMixin`, so the diffusers + # `LoraLoaderMixin` (which targets diffusers components and fuses kohya-format LoRA) does not apply. We instead + # forward the model's native, adapter-type-agnostic PEFT API, so LoRA, DoRA, and other PEFT adapters all load the + # same way. Adapters stay active and unmerged: DiffusionGemma ties the encoder and decoder base weights, so fusing + # would corrupt them. + + def load_adapter(self, *args, **kwargs): + """ + Load a PEFT adapter (LoRA, DoRA, ...) into the underlying model. 
+ + Forwards to [`~transformers.integrations.PeftAdapterMixin.load_adapter`]; the first argument is the path or Hub + id of the adapter (a directory with `adapter_config.json` and the adapter weights, e.g. the output of + [`SFTTrainer`]). + """ + return self.model.load_adapter(*args, **kwargs) + + def set_adapter(self, adapter_name: str | list[str]): + """Activate one or more loaded adapters by name.""" + self.model.set_adapter(adapter_name) + + def enable_adapters(self): + """Enable the attached adapters.""" + self.model.enable_adapters() + + def disable_adapters(self): + """Disable all adapters and run the base model.""" + self.model.disable_adapters() + + def delete_adapter(self, adapter_name: str | list[str]): + """Delete one or more loaded adapters.""" + self.model.delete_adapter(adapter_name) + + def active_adapters(self) -> list[str]: + """Names of the currently active adapters.""" + return self.model.active_adapters() if getattr(self.model, "peft_config", None) else [] + + # --- Prompt encoding --- + + def _prepare_inputs( + self, + *, + prompt: str | list[str] | None, + messages: list[dict] | None, + image: Any | list[Any] | None, + add_generation_prompt: bool, + ) -> tuple[torch.LongTensor, torch.LongTensor, dict[str, torch.Tensor]]: + """Tokenize a raw `prompt` (optionally with an `image`) or a raw `messages` conversation into + `(input_ids, attention_mask, multimodal_inputs)`, where `multimodal_inputs` holds the image tensors the + processor produced for the encoder prefill.""" + + def build_content(text, img): + if img is None: + return text + return [{"type": "image", "image": img}, {"type": "text", "text": text}] + + if messages is None: + if isinstance(prompt, list): + images = image if isinstance(image, list) else [image] * len(prompt) + messages = [[{"role": "user", "content": build_content(p, im)}] for p, im in zip(prompt, images)] + else: + messages = [{"role": "user", "content": build_content(prompt, image)}] + + encoded = 
self.processor.apply_chat_template( + messages, + add_generation_prompt=add_generation_prompt, + tokenize=True, + return_tensors="pt", + return_dict=True, + ) + ids = encoded["input_ids"] + mask = encoded.get("attention_mask") + if mask is None: + mask = torch.ones_like(ids, dtype=torch.long) + multimodal_keys = ("pixel_values", "image_position_ids", "mm_token_type_ids") + multimodal_inputs = {k: encoded[k] for k in multimodal_keys if k in encoded} + return ids, mask.to(dtype=torch.long), multimodal_inputs + + def check_inputs( + self, + prompt: str | list[str] | None, + messages: list[dict] | None, + gen_length: int, + num_inference_steps: int, + output_type: str, + callback_on_step_end_tensor_inputs: list[str] | None, + ): + if output_type not in {"seq", "text"}: + raise ValueError(f"`output_type` must be 'seq' or 'text', got {output_type!r}.") + if gen_length <= 0: + raise ValueError(f"`gen_length` must be > 0, got {gen_length}.") + if num_inference_steps <= 0: + raise ValueError(f"`num_inference_steps` must be > 0, got {num_inference_steps}.") + if prompt is None and messages is None: + raise ValueError("Provide either `prompt` or `messages`.") + if prompt is not None and messages is not None: + raise ValueError("Provide either `prompt` or `messages`, not both.") + if self.processor is None: + raise ValueError("`processor` is required to encode the prompt.") + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found " + f"{[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] | None = None, + messages: list[dict] | None = None, + image: Any | list[Any] | None = None, + add_generation_prompt: bool = 
True, + gen_length: int = 256, + num_inference_steps: int = 48, + temperature: float = 0.0, + cache_implementation: str | None = None, + eos_early_stop: bool = True, + eos_token_id: int | None = None, + generator: torch.Generator | None = None, + output_type: str = "text", + return_dict: bool = True, + callback_on_step_end: Callable[[Any, int, int, dict], dict] + | PipelineCallback + | MultiPipelineCallbacks + | None = None, + callback_on_step_end_tensor_inputs: list[str] | None = None, + ) -> DiffusionGemmaPipelineOutput | tuple[torch.LongTensor, list[str] | None]: + """ + Generate text with block diffusion. + + Args: + prompt (`str` or `List[str]`, *optional*): + Prompt text, wrapped in a chat template and tokenized by the processor. Provide either this or + `messages`. + messages (`List[Dict]`, *optional*): + A raw chat conversation to encode, e.g. `[{"role": "user", "content": "Hello"}]` or a multi-turn / + multimodal conversation. Use this instead of `prompt` for anything beyond a single user turn. + image (`PIL.Image.Image` or `List`, *optional*): + Image(s) to pair with `prompt` for multimodal generation; the processor turns them into the model's + image inputs. For richer layouts, put the image content directly in `messages`. + add_generation_prompt (`bool`, defaults to `True`): + Whether to add the generation prompt when applying the chat template. + gen_length (`int`, defaults to `256`): + Number of tokens to generate, rounded up to a multiple of the model's `canvas_length`. + num_inference_steps (`int`, defaults to `48`): + Number of denoising steps per canvas. + temperature (`float`, defaults to `0.0`): + Sampling temperature. `0.0` is greedy. Other sampling knobs (e.g. `top_k`, `threshold`) are scheduler + config; set them on the scheduler, e.g. `pipe.scheduler = + BlockRefinementScheduler.from_config(pipe.scheduler.config, top_k=...)`. 
+ cache_implementation (`str`, *optional*): + Set to `"static"` to hold the encoder KV cache in a fixed-shape `StaticCache` instead of the default + `DynamicCache`; in both modes each block of context is encoded once and reused on every step. The fixed + shapes also let you compile the decoder, e.g. `pipe.model.model.decoder = + torch.compile(pipe.model.model.decoder, fullgraph=True)`. + eos_early_stop (`bool`, defaults to `True`): + Whether to stop generating further canvases once every sequence has emitted EOS. + eos_token_id (`int`, *optional*): + EOS token ID used for early stopping and for trimming each decoded text at its first EOS. Falls back to the processor's tokenizer. + generator (`torch.Generator`, *optional*): + RNG for sampling. + output_type (`str`, defaults to `"text"`): + `"text"` decodes sequences into strings (requires a processor); `"seq"` returns token IDs only. + return_dict (`bool`, defaults to `True`): + Whether to return a [`DiffusionGemmaPipelineOutput`] instead of a tuple. + callback_on_step_end (`Callable` or `PipelineCallback`, *optional*): + Callback run after each denoising step with signature `callback_on_step_end(self, step, timestep, + callback_kwargs)`. Allowed tensor keys: `canvas`, `logits`. The callback must return a dict; a `canvas` entry in it replaces the current canvas. + callback_on_step_end_tensor_inputs (`List[str]`, *optional*): + Tensor keys to pass to the callback. Defaults to `["canvas"]`. + + Examples: + + Returns: + [`~pipelines.diffusion_gemma.pipeline_diffusion_gemma.DiffusionGemmaPipelineOutput`] or `tuple`: + The generated token IDs (`sequences`) and, for `output_type="text"`, the decoded `texts`. 
+ """ + if callback_on_step_end is not None and isinstance( + callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks) + ): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + if callback_on_step_end_tensor_inputs is None: + callback_on_step_end_tensor_inputs = ["canvas"] + + self.check_inputs( + prompt=prompt, + messages=messages, + gen_length=gen_length, + num_inference_steps=num_inference_steps, + output_type=output_type, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + prompt_ids, prompt_attention_mask, multimodal_inputs = self._prepare_inputs( + prompt=prompt, + messages=messages, + image=image, + add_generation_prompt=add_generation_prompt, + ) + + device = self._execution_device + prompt_ids = prompt_ids.to(device=device) + prompt_attention_mask = prompt_attention_mask.to(device=device) + multimodal_inputs = {k: v.to(device=device) for k, v in multimodal_inputs.items()} + batch_size, prompt_length = prompt_ids.shape + + if eos_token_id is None: + eos_token_id = self.eos_token_id + + canvas_length = self.canvas_length + num_canvases = (gen_length + canvas_length - 1) // canvas_length + # Only `BlockRefinementScheduler` takes a per-call `block_length`; the DiscreteDDIM/EntropyBound schedulers do + # not, so we pass scheduler-specific kwargs by signature. 
+ set_timesteps_kwargs = {"device": device} + if "block_length" in inspect.signature(self.scheduler.set_timesteps).parameters: + set_timesteps_kwargs["block_length"] = canvas_length + self.scheduler.set_timesteps(num_inference_steps, **set_timesteps_kwargs) + step_param_names = set(inspect.signature(self.scheduler.step).parameters) + self._num_timesteps = num_inference_steps * num_canvases + + cur_input_ids = prompt_ids + cur_attention_mask = prompt_attention_mask + finished = torch.zeros(batch_size, dtype=torch.bool, device=device) + global_step = 0 + + # Encode each block of context once into a reusable KV cache and run the decoder against it, rather than + # re-encoding the whole sequence on every denoising step. The default `DynamicCache` grows with the context; + # `cache_implementation="static"` uses a fixed-shape `StaticCache` so the decoder can be `torch.compile`-d. + use_static_cache = cache_implementation == "static" + text_config = self.model.config.get_text_config(decoder=True) + max_cache_len = prompt_length + num_canvases * canvas_length + if use_static_cache: + past_key_values = StaticCache(config=text_config, max_cache_len=max_cache_len) + else: + past_key_values = DynamicCache(config=text_config) + + progress_bar = tqdm(range(num_canvases), **getattr(self, "_progress_bar_config", {})) + for _ in progress_bar: + cur_len = cur_input_ids.shape[1] + decoder_position_ids = torch.arange(cur_len, cur_len + canvas_length, device=device).unsqueeze(0) + + # Encode the tokens not yet in the cache (the whole prompt on the first block, the last committed canvas + # afterwards), so the decoder reuses the encoder KV cache instead of re-encoding the full sequence. 
+ cached_len = past_key_values.get_seq_length() + self.model.model.encoder( + input_ids=cur_input_ids[:, cached_len:], + attention_mask=cur_attention_mask, + past_key_values=past_key_values, + position_ids=torch.arange(cached_len, cur_len, device=device).unsqueeze(0), + # Image tensors are consumed by the prompt prefill only; later blocks encode text-only canvases. + **(multimodal_inputs if cached_len == 0 else {}), + ) + + # Build the 4D decoder mask once per block (outside any compiled region). A static cache spans its full + # buffer; a dynamic cache spans only the populated length. + cache_buffer_len = max_cache_len if use_static_cache else cur_len + decoder_attention_mask = torch.zeros( + (batch_size, cache_buffer_len + canvas_length), dtype=torch.bool, device=device + ) + decoder_attention_mask[:, :cur_len] = cur_attention_mask.bool() + decoder_attention_mask[:, -canvas_length:] = True + mask_mapping = self.model.model.decoder.create_diffusion_decoder_attention_mask( + config=text_config, + inputs_embeds=torch.empty((batch_size, canvas_length, 0), device=device), + past_key_values=past_key_values, + decoder_attention_mask=decoder_attention_mask, + ) + + # Start from a fully random canvas and denoise it; the scheduler resets its committed state at step 0. + canvas = torch.randint(0, self.vocab_size, (batch_size, canvas_length), device=device, generator=generator) + self_conditioning_logits = None + + for step_idx in range(num_inference_steps): + logits = self.model( + decoder_input_ids=canvas, + past_key_values=past_key_values, + self_conditioning_logits=self_conditioning_logits, + decoder_attention_mask=mask_mapping, + decoder_position_ids=decoder_position_ids, + ).logits + self_conditioning_logits = logits + + # Pass only the kwargs the chosen scheduler accepts, so any of the schedulers can drive the pipeline. + # Per-scheduler sampling knobs (thresholds, top-k, ...) live on the scheduler config, not here. 
+ step_kwargs = {"mask_token_id": None, "temperature": temperature, "generator": generator} + step_kwargs = {k: v for k, v in step_kwargs.items() if k in step_param_names} + scheduler_output = self.scheduler.step( + model_output=logits, timestep=step_idx, sample=canvas, return_dict=True, **step_kwargs + ) + canvas = scheduler_output.prev_sample + + # Predictor-corrector (https://huggingface.co/papers/2605.22765): a scheduler may expose extra corrector + # sweeps via a `corrector_steps` config and a `step_correct` method. Each sweep needs fresh logits on the + # updated canvas, so the model is recomputed here. Skipped on the final step. + corrector_steps = getattr(self.scheduler.config, "corrector_steps", 0) + if corrector_steps and step_idx + 1 < num_inference_steps: + for _ in range(corrector_steps): + corrector_logits = self.model( + decoder_input_ids=canvas, + past_key_values=past_key_values, + self_conditioning_logits=self_conditioning_logits, + decoder_attention_mask=mask_mapping, + decoder_position_ids=decoder_position_ids, + ).logits + canvas = self.scheduler.step_correct( + model_output=corrector_logits, timestep=step_idx, sample=canvas, generator=generator + ).prev_sample + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, global_step, step_idx, callback_kwargs) + canvas = callback_outputs.pop("canvas", canvas) + global_step += 1 + + # Append the denoised canvas and extend the context for the next block. 
+ cur_input_ids = torch.cat([cur_input_ids, canvas], dim=-1) + cur_attention_mask = F.pad(cur_attention_mask, (0, canvas_length), value=1) + + if eos_early_stop and eos_token_id is not None: + finished = finished | (canvas == eos_token_id).any(dim=-1) + if finished.all(): + break + + progress_bar.close() + + sequences = cur_input_ids[:, prompt_length:] + + # Trim each row at its first EOS so post-EOS canvas tokens don't leak into the decoded text. + decode_sequences: list[torch.LongTensor] | torch.LongTensor = sequences + if eos_token_id is not None: + decode_sequences = [ + seq[: int((seq == eos_token_id).nonzero(as_tuple=True)[0][0]) + 1] + if (seq == eos_token_id).any() + else seq + for seq in sequences + ] + + texts = None + if output_type == "text" and self.processor is not None: + texts = self.processor.batch_decode(decode_sequences, skip_special_tokens=True) + + if not return_dict: + return sequences, texts + return DiffusionGemmaPipelineOutput(sequences=sequences, texts=texts) + + +__all__ = ["DiffusionGemmaPipeline", "DiffusionGemmaPipelineOutput"] diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index 447586c6f436..b7332480f822 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -41,6 +41,8 @@ _import_structure["deprecated"] = ["KarrasVeScheduler", "ScoreSdeVpScheduler"] _import_structure["scheduling_amused"] = ["AmusedScheduler"] _import_structure["scheduling_block_refinement"] = ["BlockRefinementScheduler", "BlockRefinementSchedulerOutput"] + _import_structure["scheduling_discrete_ddim"] = ["DiscreteDDIMScheduler", "DiscreteDDIMSchedulerOutput"] + _import_structure["scheduling_entropy_bound"] = ["EntropyBoundScheduler", "EntropyBoundSchedulerOutput"] _import_structure["scheduling_consistency_decoder"] = ["ConsistencyDecoderScheduler"] _import_structure["scheduling_consistency_models"] = ["CMStochasticIterativeScheduler"] _import_structure["scheduling_ddim"] = 
["DDIMScheduler"] @@ -158,12 +160,14 @@ from .scheduling_ddpm_parallel import DDPMParallelScheduler from .scheduling_ddpm_wuerstchen import DDPMWuerstchenScheduler from .scheduling_deis_multistep import DEISMultistepScheduler + from .scheduling_discrete_ddim import DiscreteDDIMScheduler, DiscreteDDIMSchedulerOutput from .scheduling_dpm_cogvideox import CogVideoXDPMScheduler from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler from .scheduling_dpmsolver_multistep_inverse import DPMSolverMultistepInverseScheduler from .scheduling_dpmsolver_singlestep import DPMSolverSinglestepScheduler from .scheduling_edm_dpmsolver_multistep import EDMDPMSolverMultistepScheduler from .scheduling_edm_euler import EDMEulerScheduler + from .scheduling_entropy_bound import EntropyBoundScheduler, EntropyBoundSchedulerOutput from .scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler from .scheduling_euler_discrete import EulerDiscreteScheduler from .scheduling_flow_map_euler_discrete import FlowMapEulerDiscreteScheduler diff --git a/src/diffusers/schedulers/scheduling_block_refinement.py b/src/diffusers/schedulers/scheduling_block_refinement.py index 3b4d737767ce..2844fb8f0d29 100644 --- a/src/diffusers/schedulers/scheduling_block_refinement.py +++ b/src/diffusers/schedulers/scheduling_block_refinement.py @@ -74,6 +74,8 @@ def __init__( self.num_inference_steps = num_inference_steps self.timesteps = torch.arange(self.num_inference_steps - 1, -1, -1, dtype=torch.long) self._transfer_schedule: torch.LongTensor | None = None + # committed positions for the uniform corruption mode (no mask token); reset at the start of each block + self._committed: torch.BoolTensor | None = None def set_timesteps( self, @@ -92,6 +94,7 @@ def set_timesteps( self._transfer_schedule = self.get_num_transfer_tokens(block_length, self.num_inference_steps).to( device=device if device is not None else "cpu" ) + self._committed = None def get_num_transfer_tokens(self, 
block_length: int, num_inference_steps: int) -> torch.LongTensor: """Evenly distribute `block_length` token commits across `num_inference_steps` steps.""" @@ -178,7 +181,7 @@ def step( timestep: int | torch.Tensor, sample: torch.LongTensor, *, - mask_token_id: int, + mask_token_id: int | None = None, temperature: float = 0.0, top_p: float | None = None, top_k: int | None = None, @@ -203,9 +206,11 @@ def step( timestep (`int` or `torch.Tensor`): Current step index within the block's refinement schedule. sample (`torch.LongTensor` of shape `(batch_size, block_length)`): - Current block token IDs (contains mask tokens for uncommitted positions). - mask_token_id (`int`): - Token ID used for masked positions. + Current block token IDs (contains mask tokens for uncommitted positions in the mask-based mode). + mask_token_id (`int`, *optional*): + Token ID used for masked positions. When `None`, the scheduler runs in uniform corruption mode: it + tracks committed positions internally (resetting at `timestep == 0`) and renoises the uncommitted ones + with uniformly random tokens, matching DiffusionGemma's block refinement sampler. temperature (`float`): Sampling temperature. 
top_p (`float`, *optional*): @@ -247,14 +252,54 @@ def step( ) batch_size, block_length = sample.shape - active_block = sample == mask_token_id - masks_remaining = active_block.any() if isinstance(timestep, torch.Tensor): step_index = int(timestep.item()) else: step_index = int(timestep) + # --- Uniform corruption mode (DiffusionGemma): no mask token, committed positions tracked as state --- + if mask_token_id is None: + if step_index == 0 or self._committed is None or self._committed.shape != sample.shape: + self._committed = torch.zeros_like(sample, dtype=torch.bool) + committed = self._committed + confidence = sampled_probs.to(dtype=torch.float32) + + # Cumulative quota: evenly distribute the block across the steps, commit what is still owed + steps_done = step_index + 1 + target = (steps_done * block_length + self.num_inference_steps - 1) // self.num_inference_steps + needed = (target - committed.sum(dim=-1)).clamp(min=0) + + masked_confidence = confidence.masked_fill(committed, float("-inf")) + ranks = masked_confidence.argsort(dim=-1, descending=True).argsort(dim=-1) + transfer_index = ~committed & ((ranks < needed[:, None]) | (confidence > threshold)) + + editing_transfer_index = torch.zeros_like(transfer_index) + if editing_threshold is not None: + editing_transfer_index = ( + committed & (sampled_tokens != sample) & (confidence > float(editing_threshold)) + ) + + prev_sample = torch.where(transfer_index | editing_transfer_index, sampled_tokens, sample) + self._committed = committed | transfer_index + random_tokens = torch.randint( + low=0, high=model_output.shape[-1], size=sample.shape, device=sample.device, generator=generator + ) + prev_sample = torch.where(self._committed, prev_sample, random_tokens) + + if not return_dict: + return prev_sample, transfer_index, editing_transfer_index, sampled_tokens, sampled_probs + return BlockRefinementSchedulerOutput( + prev_sample=prev_sample, + transfer_index=transfer_index, + 
@dataclass
class DiscreteDDIMSchedulerOutput(BaseOutput):
    """
    Output class for the discrete DDIM scheduler.

    Args:
        prev_sample (`torch.LongTensor` of shape `(batch_size, block_length)`):
            Updated block tokens after the current denoising step.
        sampled_tokens (`torch.LongTensor`):
            For [`~DiscreteDDIMScheduler.step`], the token IDs sampled from the model logits, i.e. the predicted
            clean tokens `x0`, of shape `(batch_size, block_length)`. For [`~DiscreteDDIMScheduler.step_correct`],
            only the tokens drawn for the resampled positions, of shape `(batch_size, k)` where `k` is `corrector_k`
            clamped to `[1, block_length]`.
        sampled_probs (`torch.Tensor`):
            Probabilities of `sampled_tokens`, with the same shape as `sampled_tokens`.
    """

    prev_sample: torch.LongTensor
    sampled_tokens: torch.LongTensor
    sampled_probs: torch.Tensor


class DiscreteDDIMScheduler(SchedulerMixin, ConfigMixin):
    """
    Discrete DDIM scheduler for the uniform corruption process, following "Structured Denoising Diffusion Models in
    Discrete State-Spaces" (D3PM, https://huggingface.co/papers/2107.03006).

    On the linear schedule the survival probability of a clean token at time `t` is `alpha(t) = 1 - t`. One denoising
    step from time `t` to `s < t` samples every block position from the exact posterior `q(x_s | x_t, x0)`, which for
    the uniform kernel decomposes into three routes: jump to the predicted clean token `x0`, stay on the current token,
    or jump to a uniformly random token. Unlike masked diffusion, there is no mask token; uncommitted positions carry
    random tokens.

    An optional predictor-corrector mode follows "Reparameterizing Uniform Diffusion Models" via the leave-one-out
    (LOO) denoiser (https://huggingface.co/papers/2605.22765). When `corrector_steps > 0`, the pipeline runs that many
    Gibbs corrector sweeps after each predictor step (see [`~DiscreteDDIMScheduler.step_correct`]), resampling the
    least-confident positions from the one-coordinate conditional `Cat(alpha_s * x0_loo + (1 - alpha_s) / K)` while
    holding the rest fixed, which leaves the marginal `p_s` invariant and improves generation at no training cost.

    Args:
        num_inference_steps (`int`, defaults to 32):
            The number of denoising steps, defining the linear time grid the posterior is evaluated on.
        corrector_steps (`int`, defaults to 0):
            Number of Gibbs corrector sweeps run after each predictor step. `0` recovers plain ancestral DDIM sampling.
        corrector_k (`int`, defaults to 1):
            Number of positions resampled per corrector sweep.
        corrector_selection (`str`, defaults to `"lowest_log_margin"`):
            How the resampled positions are chosen: `"lowest_log_margin"`, `"lowest_maxprob"`, `"lowest_current_prob"`,
            or `"random"`.
        corrector_selection_tau (`float`, defaults to 1.0):
            Temperature of the Gumbel-top-k position selection (lower is greedier). Must be > 0.
    """

    order = 1

    @register_to_config
    def __init__(
        self,
        num_inference_steps: int = 32,
        corrector_steps: int = 0,
        corrector_k: int = 1,
        corrector_selection: str = "lowest_log_margin",
        corrector_selection_tau: float = 1.0,
    ):
        # Only the step count is kept as an attribute. The `corrector_k`, `corrector_selection` and
        # `corrector_selection_tau` options are read back through `self.config` in `_select_positions`;
        # `corrector_steps` is not read anywhere in this class.
        self.num_inference_steps = num_inference_steps
        self.timesteps = torch.arange(num_inference_steps, dtype=torch.long)

    def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None:
        """
        Rebuild the time grid as the ascending step indices `0 .. num_inference_steps - 1`.

        Args:
            num_inference_steps (`int`):
                Number of denoising steps; must be strictly positive.
            device (`str` or `torch.device`, *optional*):
                Device on which the `timesteps` tensor is created.

        Raises:
            ValueError: If `num_inference_steps` is not strictly positive.
        """
        if num_inference_steps <= 0:
            raise ValueError(f"`num_inference_steps` must be > 0, got {num_inference_steps}.")
        self.num_inference_steps = num_inference_steps
        self.timesteps = torch.arange(num_inference_steps, device=device, dtype=torch.long)

    @staticmethod
    def _sample_from_logits(
        logits: torch.Tensor,
        *,
        temperature: float,
        generator: torch.Generator | None,
    ) -> tuple[torch.LongTensor, torch.Tensor]:
        """
        Sample one token per position with optional temperature, returning tokens and their probabilities.

        `temperature == 0` takes the argmax; otherwise tokens are drawn from the temperature-scaled softmax. The
        returned probabilities always come from the unscaled float32 softmax of `logits`.

        Raises:
            ValueError: If `temperature` is negative.
        """
        if temperature < 0:
            raise ValueError(f"`temperature` must be >= 0, got {temperature}.")

        vocab_size = logits.shape[-1]
        flat_logits = logits.reshape(-1, vocab_size)
        probs = torch.softmax(flat_logits.float(), dim=-1)

        if temperature == 0.0:
            token = flat_logits.argmax(dim=-1, keepdim=True)
        else:
            scaled_probs = torch.softmax(flat_logits.float() / temperature, dim=-1)
            token = torch.multinomial(scaled_probs, num_samples=1, generator=generator)

        token_prob = torch.gather(probs, -1, token)
        return token.view(*logits.shape[:-1]), token_prob.view(*logits.shape[:-1])

    def _resolve_step_index(self, timestep: int | torch.Tensor) -> int:
        """Convert `timestep` to a Python int and check that it lies on the current time grid."""
        step_index = int(timestep.item()) if isinstance(timestep, torch.Tensor) else int(timestep)
        if not 0 <= step_index < self.num_inference_steps:
            # Outside the grid the alphas leave [0, 1], which makes the route masses negative in `step` and the
            # logarithms undefined in `step_correct`; fail early with an actionable message instead.
            raise ValueError(f"`timestep` must be in [0, {self.num_inference_steps - 1}], got {step_index}.")
        return step_index

    def _alpha(self, step_index: int) -> float:
        """
        Survival probability of a clean token at the time grid point `step_index`.

        Returns `step_index / num_inference_steps`, i.e. `0` at grid point 0 and `1` at grid point
        `num_inference_steps`.
        """
        return step_index / self.num_inference_steps

    @staticmethod
    def _to_loo_logits(logits: torch.Tensor, tokens: torch.LongTensor, alpha: float) -> torch.Tensor:
        """
        Convert plain-denoiser logits to the leave-one-out posterior for the uniform kernel.

        Subtracts `log(1 + K * alpha / (1 - alpha))` from the observed token's logit (eq. 13 of
        https://huggingface.co/papers/2605.22765); renormalization happens in the following softmax. For `alpha`
        outside the open interval `(0, 1)` the logits are returned unchanged (and not copied).
        """
        if alpha <= 0.0 or alpha >= 1.0:
            return logits
        delta = math.log1p(logits.shape[-1] * alpha / (1.0 - alpha))
        shifted = logits.clone()
        src = torch.full((*tokens.shape, 1), -delta, dtype=shifted.dtype, device=shifted.device)
        shifted.scatter_add_(-1, tokens.unsqueeze(-1), src)
        return shifted

    def step(
        self,
        model_output: torch.Tensor,
        timestep: int | torch.Tensor,
        sample: torch.LongTensor,
        *,
        temperature: float = 0.0,
        generator: torch.Generator | None = None,
        return_dict: bool = True,
    ) -> DiscreteDDIMSchedulerOutput | tuple[torch.LongTensor, torch.LongTensor, torch.Tensor]:
        """
        Sample the next block from the posterior `q(x_s | x_t, x0)` of the uniform corruption process.

        With `a = alpha_t / alpha_s` (survival probability from `s` to `t`) and `b = alpha_s`, the posterior mass of
        each route is

        clean: `b * (1 - a) / K + a * b * 1[x_t = x0]`, stay: `a * (1 - b) / K`, noise: `(1 - a) * (1 - b) / K`,

        so the last step (`b = 1`) deterministically commits the predicted clean tokens.

        Args:
            model_output (`torch.Tensor` of shape `(batch_size, block_length, vocab_size)`):
                Raw logits from the model for the current block.
            timestep (`int` or `torch.Tensor`):
                Current step index within the denoising schedule, in `[0, num_inference_steps - 1]`.
            sample (`torch.LongTensor` of shape `(batch_size, block_length)`):
                Current block token IDs `x_t`.
            temperature (`float`):
                Sampling temperature applied to the logits when drawing `x0`.
            generator (`torch.Generator`, *optional*):
                RNG for sampling.
            return_dict (`bool`):
                Whether to return a [`DiscreteDDIMSchedulerOutput`] or a plain tuple.

        Raises:
            ValueError: If `timestep` is outside `[0, num_inference_steps - 1]` or `temperature` is negative.
        """
        step_index = self._resolve_step_index(timestep)

        sampled_tokens, sampled_probs = self._sample_from_logits(
            model_output, temperature=temperature, generator=generator
        )

        vocab_size = model_output.shape[-1]
        # `step_index` counts up from 0 to `num_inference_steps - 1`: alpha increases towards the clean end, with
        # alpha_s = 1 on the final step so the predicted clean tokens are committed deterministically.
        alpha_t = self._alpha(step_index)
        alpha_s = self._alpha(step_index + 1)
        survival = alpha_t / alpha_s  # alpha_s >= 1 / num_inference_steps > 0, so this never divides by zero

        same = (sample == sampled_tokens).float()
        clean_mass = alpha_s * (1 - survival) / vocab_size + survival * alpha_s * same
        stay_mass = survival * (1 - alpha_s) / vocab_size * torch.ones_like(same)
        noise_mass = (1 - survival) * (1 - alpha_s) / vocab_size * torch.ones_like(same)

        # Route 0 = jump to x0, route 1 = keep the current token, route 2 = jump to a uniformly random token.
        route_probs = torch.stack([clean_mass, stay_mass, noise_mass], dim=-1)
        route_probs = route_probs / route_probs.sum(dim=-1, keepdim=True)
        routes = torch.multinomial(route_probs.view(-1, 3), num_samples=1, generator=generator).view_as(sample)

        random_tokens = torch.randint(
            low=0, high=vocab_size, size=sample.shape, device=sample.device, generator=generator
        )
        prev_sample = torch.where(routes == 0, sampled_tokens, sample)
        prev_sample = torch.where(routes == 2, random_tokens, prev_sample)

        if not return_dict:
            return prev_sample, sampled_tokens, sampled_probs
        return DiscreteDDIMSchedulerOutput(
            prev_sample=prev_sample,
            sampled_tokens=sampled_tokens,
            sampled_probs=sampled_probs,
        )

    def _select_positions(
        self, sample: torch.LongTensor, cond_log_probs: torch.Tensor, generator: torch.Generator | None
    ) -> torch.LongTensor:
        """
        Pick `corrector_k` positions per row to resample, least-confident first (Gumbel-top-k without replacement).

        `corrector_k` is clamped to `[1, block_length]`. With `corrector_selection == "random"` the positions are
        drawn uniformly and `corrector_selection_tau` is not used.

        Raises:
            ValueError: If `corrector_selection` is unknown, or if `corrector_selection_tau` is not strictly
                positive when a confidence-based selection is used.
        """
        selection = self.config.corrector_selection
        batch_size, seq_len = sample.shape
        k_eff = min(max(1, int(self.config.corrector_k)), seq_len)

        if selection == "random":
            scores = torch.rand(batch_size, seq_len, device=sample.device, generator=generator)
            return torch.topk(scores, k=k_eff, dim=-1).indices

        # `uncertainty` is larger for less confident positions, so top-k below picks the least confident ones.
        if selection == "lowest_maxprob":
            uncertainty = -cond_log_probs.max(dim=-1).values
        elif selection == "lowest_current_prob":
            uncertainty = -torch.gather(cond_log_probs, -1, sample.unsqueeze(-1)).squeeze(-1)
        elif selection == "lowest_log_margin":
            log_current = torch.gather(cond_log_probs, -1, sample.unsqueeze(-1)).squeeze(-1)
            alt = cond_log_probs.clone().scatter_(-1, sample.unsqueeze(-1), float("-inf"))
            uncertainty = -(log_current - alt.max(dim=-1).values)
        else:
            raise ValueError(f"Unknown `corrector_selection`: {selection!r}.")

        tau = float(self.config.corrector_selection_tau)
        if not tau > 0.0:
            # tau == 0 would yield inf/NaN keys and tau < 0 would silently prefer the most confident positions.
            raise ValueError(f"`corrector_selection_tau` must be > 0, got {tau}.")

        keys = uncertainty / tau
        u = torch.rand(keys.shape, device=keys.device, generator=generator).clamp_(1e-12, 1.0 - 1e-12)
        keys = keys + (-torch.log(-torch.log(u)))  # add Gumbel(0, 1) noise
        return torch.topk(keys, k=k_eff, dim=-1).indices

    def step_correct(
        self,
        model_output: torch.Tensor,
        timestep: int | torch.Tensor,
        sample: torch.LongTensor,
        *,
        generator: torch.Generator | None = None,
        return_dict: bool = True,
    ) -> DiscreteDDIMSchedulerOutput | tuple[torch.LongTensor, torch.LongTensor, torch.Tensor]:
        """
        Run one Gibbs corrector sweep at the post-predictor time `s`, following the leave-one-out predictor-corrector
        of https://huggingface.co/papers/2605.22765.

        The model logits (recomputed on the current `sample`) are converted to the LOO denoiser, the one-coordinate
        conditional `p_s(x^l | x^{-l}) = Cat(alpha_s * x0_loo + (1 - alpha_s) / K)` is formed, the least-confident
        `corrector_k` positions are selected, and those positions are resampled while the rest are held fixed. The
        sweep preserves `p_s`, so it refines the sample without changing its marginal and needs no extra training.

        On the final grid point (`alpha_s == 1`) the uniform component has zero weight, so the conditional reduces to
        the softmax of the model logits.

        Args:
            model_output (`torch.Tensor` of shape `(batch_size, block_length, vocab_size)`):
                Raw logits from the model recomputed on the current (post-predictor) `sample`.
            timestep (`int` or `torch.Tensor`):
                The predictor step index just completed, in `[0, num_inference_steps - 1]`; the corrector runs at the
                following grid point `s`.
            sample (`torch.LongTensor` of shape `(batch_size, block_length)`):
                Current block token IDs to refine. Not modified in place.
            generator (`torch.Generator`, *optional*):
                RNG for sampling.
            return_dict (`bool`):
                Whether to return a [`DiscreteDDIMSchedulerOutput`] or a plain tuple. In both cases the sampled
                tokens and probabilities cover only the resampled positions, shape `(batch_size, k)`.

        Raises:
            ValueError: If `timestep` is outside `[0, num_inference_steps - 1]`.
        """
        step_index = self._resolve_step_index(timestep)

        # The corrector acts at the cleaner time `s` reached by the predictor.
        alpha_s = self._alpha(step_index + 1)
        vocab_size = model_output.shape[-1]

        # Match the reference corrector, which forms the conditional in float64 (the LOO correction reaches ~log(K)).
        loo_logits = self._to_loo_logits(model_output.double(), sample, alpha_s)
        loo_log_probs = torch.log_softmax(loo_logits, dim=-1)
        if alpha_s >= 1.0:
            # log(1 - alpha_s) is undefined here (`math.log1p(-1.0)` raises a domain error); the uniform term has
            # zero mass, so the conditional is exactly the (unshifted) denoiser distribution.
            cond_log_probs = loo_log_probs
        else:
            log_uniform = math.log1p(-alpha_s) - math.log(vocab_size)
            cond_log_probs = torch.logaddexp(
                math.log(alpha_s) + loo_log_probs, torch.full_like(loo_log_probs, log_uniform)
            )

        positions = self._select_positions(sample, cond_log_probs, generator)
        rows = torch.arange(sample.shape[0], device=sample.device).unsqueeze(-1).expand_as(positions)
        chosen_probs = cond_log_probs[rows, positions].exp()  # (batch_size, k, vocab_size)
        resampled = torch.multinomial(
            chosen_probs.reshape(-1, vocab_size), num_samples=1, generator=generator
        ).view_as(positions)

        prev_sample = sample.clone()
        prev_sample[rows, positions] = resampled
        sampled_probs = torch.gather(chosen_probs, -1, resampled.unsqueeze(-1)).squeeze(-1)

        if not return_dict:
            return prev_sample, resampled, sampled_probs
        return DiscreteDDIMSchedulerOutput(
            prev_sample=prev_sample,
            sampled_tokens=resampled,
            sampled_probs=sampled_probs,
        )


__all__ = ["DiscreteDDIMScheduler", "DiscreteDDIMSchedulerOutput"]
@dataclass
class EntropyBoundSchedulerOutput(BaseOutput):
    """
    Output class for the entropy bound scheduler.

    Args:
        prev_sample (`torch.LongTensor` of shape `(batch_size, block_length)`):
            Updated block tokens after the current denoising step.
        accepted_index (`torch.BoolTensor` of shape `(batch_size, block_length)`):
            Boolean mask of the positions accepted (committed) in this step.
        sampled_tokens (`torch.LongTensor` of shape `(batch_size, block_length)`):
            Token IDs sampled from the model logits.
        sampled_probs (`torch.Tensor` of shape `(batch_size, block_length)`):
            Probabilities of the sampled tokens.
    """

    prev_sample: torch.LongTensor
    accepted_index: torch.BoolTensor
    sampled_tokens: torch.LongTensor
    sampled_probs: torch.Tensor


class EntropyBoundScheduler(SchedulerMixin, ConfigMixin):
    """
    Entropy bound scheduler for the uniform corruption process.

    At each step the scheduler samples a candidate token per position and accepts the `k` lowest-entropy positions such
    that `sum_i^k entropy_i - max(entropy_1, ..., entropy_k) <= entropy_bound`. The left-hand side upper-bounds the
    joint mutual information between the accepted tokens, so they are approximately independent. Accepted positions
    keep their sampled token; the rest are renoised with uniformly random tokens (there is no mask token).

    Proposed in "Beyond Next-Token Prediction" (https://huggingface.co/papers/2505.24857).

    Args:
        entropy_bound (`float`, defaults to 0.1):
            The maximum tolerated joint entropy of the accepted tokens. Larger values accept more tokens per step.
            Must be non-negative.
        num_inference_steps (`int`, defaults to 32):
            The maximum number of denoising steps.
    """

    order = 1

    @register_to_config
    def __init__(self, entropy_bound: float = 0.1, num_inference_steps: int = 32):
        # `entropy_bound` is not stored as an attribute here; `step` reads it back through `self.config`.
        self.num_inference_steps = num_inference_steps
        self.timesteps = torch.arange(num_inference_steps, dtype=torch.long)

    def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None:
        """
        Rebuild the time grid as the ascending step indices `0 .. num_inference_steps - 1`.

        Args:
            num_inference_steps (`int`):
                Number of denoising steps; must be strictly positive.
            device (`str` or `torch.device`, *optional*):
                Device on which the `timesteps` tensor is created.

        Raises:
            ValueError: If `num_inference_steps` is not strictly positive.
        """
        if num_inference_steps <= 0:
            raise ValueError(f"`num_inference_steps` must be > 0, got {num_inference_steps}.")
        self.num_inference_steps = num_inference_steps
        self.timesteps = torch.arange(num_inference_steps, device=device, dtype=torch.long)

    @staticmethod
    def _sample_from_logits(
        logits: torch.Tensor,
        *,
        temperature: float,
        generator: torch.Generator | None,
    ) -> tuple[torch.LongTensor, torch.Tensor]:
        """
        Sample one token per position with optional temperature, returning tokens and their probabilities.

        `temperature == 0` takes the argmax; otherwise tokens are drawn from the temperature-scaled softmax. The
        returned probabilities always come from the unscaled float32 softmax of `logits`.

        Raises:
            ValueError: If `temperature` is negative.
        """
        if temperature < 0:
            raise ValueError(f"`temperature` must be >= 0, got {temperature}.")

        vocab_size = logits.shape[-1]
        flat_logits = logits.reshape(-1, vocab_size)
        probs = torch.softmax(flat_logits.float(), dim=-1)

        if temperature == 0.0:
            token = flat_logits.argmax(dim=-1, keepdim=True)
        else:
            scaled_probs = torch.softmax(flat_logits.float() / temperature, dim=-1)
            token = torch.multinomial(scaled_probs, num_samples=1, generator=generator)

        token_prob = torch.gather(probs, -1, token)
        return token.view(*logits.shape[:-1]), token_prob.view(*logits.shape[:-1])

    def step(
        self,
        model_output: torch.Tensor,
        timestep: int | torch.Tensor,
        sample: torch.LongTensor,
        *,
        entropy_bound: float | None = None,
        temperature: float = 1.0,
        generator: torch.Generator | None = None,
        return_dict: bool = True,
    ) -> EntropyBoundSchedulerOutput | tuple[torch.LongTensor, torch.BoolTensor, torch.LongTensor, torch.Tensor]:
        """
        Accept the lowest-entropy positions under the entropy bound and renoise the rest.

        The method keeps no state between calls: acceptance is recomputed from `model_output` on every step, and
        every position that is not accepted in this call is overwritten with a uniformly random token.

        Args:
            model_output (`torch.Tensor` of shape `(batch_size, block_length, vocab_size)`):
                Raw logits from the model for the current block.
            timestep (`int` or `torch.Tensor`):
                Current step index within the denoising schedule. Unused; kept for API consistency.
            sample (`torch.LongTensor` of shape `(batch_size, block_length)`):
                Current block token IDs. Only its shape and device are used.
            entropy_bound (`float`, *optional*):
                Overrides the configured entropy bound for this step.
            temperature (`float`):
                Sampling temperature applied to the logits when drawing the candidate tokens.
            generator (`torch.Generator`, *optional*):
                RNG for sampling.
            return_dict (`bool`):
                Whether to return an [`EntropyBoundSchedulerOutput`] or a plain tuple.

        Raises:
            ValueError: If the effective `entropy_bound` is negative or NaN, or `temperature` is negative.
        """
        if entropy_bound is None:
            entropy_bound = float(self.config.entropy_bound)
        if not entropy_bound >= 0:
            # A negative (or NaN) bound rejects even the single lowest-entropy position, so every step would renoise
            # the whole block and sampling could never make progress.
            raise ValueError(f"`entropy_bound` must be >= 0, got {entropy_bound}.")

        sampled_tokens, sampled_probs = self._sample_from_logits(
            model_output, temperature=temperature, generator=generator
        )

        # Compute the entropy in float32, matching `_sample_from_logits`, so half-precision logits do not perturb the
        # ordering or the cumulative sum compared against the bound.
        token_entropy = torch.distributions.Categorical(logits=model_output.float()).entropy()  # (batch, block_length)
        sorted_token_entropy, sorted_indices = torch.sort(token_entropy, dim=-1, descending=False)
        cumulative_entropy = torch.cumsum(sorted_token_entropy, dim=-1)

        # `sorted_token_entropy` is the running maximum entropy (ascending order), so the left-hand side bounds the
        # joint mutual information of the accepted tokens.
        sorted_accepted = cumulative_entropy - sorted_token_entropy <= entropy_bound
        # Undo the sort: scatter the acceptance flags back to their original positions.
        accepted_index = torch.scatter(
            input=torch.zeros_like(sorted_accepted), dim=-1, index=sorted_indices, src=sorted_accepted
        )

        random_tokens = torch.randint(
            low=0, high=model_output.shape[-1], size=sample.shape, device=sample.device, generator=generator
        )
        prev_sample = torch.where(accepted_index, sampled_tokens, random_tokens)

        if not return_dict:
            return prev_sample, accepted_index, sampled_tokens, sampled_probs
        return EntropyBoundSchedulerOutput(
            prev_sample=prev_sample,
            accepted_index=accepted_index,
            sampled_tokens=sampled_tokens,
            sampled_probs=sampled_probs,
        )


__all__ = ["EntropyBoundScheduler", "EntropyBoundSchedulerOutput"]
EntropyBoundScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class EntropyBoundSchedulerOutput(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class EulerAncestralDiscreteScheduler(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 0747e76cf715..e96728d568fd 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1457,6 +1457,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class DiffusionGemmaPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class DiffusionGemmaPipelineOutput(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + 
requires_backends(cls, ["torch", "transformers"]) + + class DreamLiteMobilePipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/pipelines/diffusion_gemma/__init__.py b/tests/pipelines/diffusion_gemma/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/diffusion_gemma/test_diffusion_gemma.py b/tests/pipelines/diffusion_gemma/test_diffusion_gemma.py new file mode 100644 index 000000000000..4110063a70eb --- /dev/null +++ b/tests/pipelines/diffusion_gemma/test_diffusion_gemma.py @@ -0,0 +1,219 @@ +import unittest + +import torch + +from diffusers import BlockRefinementScheduler, DiffusionGemmaPipeline +from diffusers.utils.testing_utils import require_peft_backend, require_peft_version_greater + + +# --- Lightweight stand-in for input-validation tests that never reach the model --- + + +class _DummyTextConfig: + def __init__(self, vocab_size: int): + self.vocab_size = int(vocab_size) + self.eos_token_id = None + + +class _DummyConfig: + def __init__(self, canvas_length: int, vocab_size: int): + self.canvas_length = int(canvas_length) + self._text_config = _DummyTextConfig(vocab_size) + + def get_text_config(self, decoder: bool = False): + return self._text_config + + +class _DummyModel(torch.nn.Module): + def __init__(self, vocab_size: int = 32, canvas_length: int = 8): + super().__init__() + self.config = _DummyConfig(canvas_length, vocab_size) + + +def _make_dummy_pipeline(processor=None, canvas_length: int = 8): + model = _DummyModel(vocab_size=32, canvas_length=canvas_length) + return DiffusionGemmaPipeline(model=model, scheduler=BlockRefinementScheduler(), processor=processor) + + +class DiffusionGemmaPipelineInputTest(unittest.TestCase): + """Input validation and prompt encoding, which short-circuit before the model is called.""" + + def test_no_inputs_raises(self): + pipe = _make_dummy_pipeline() + with self.assertRaises(ValueError): + pipe(gen_length=8, num_inference_steps=2, 
output_type="seq") + + def test_output_type_invalid_raises(self): + pipe = _make_dummy_pipeline() + with self.assertRaises(ValueError): + pipe(prompt="hi", gen_length=8, output_type="invalid") + + def test_prompt_and_messages_together_raises(self): + pipe = _make_dummy_pipeline() + with self.assertRaises(ValueError): + pipe(prompt="hi", messages=[{"role": "user", "content": "hi"}], gen_length=8, output_type="seq") + + +# --- End-to-end generation: the prefill-once path drives the real encoder/decoder, so it needs the tiny model --- + +_MODEL_ID = "trl-internal-testing/tiny-DiffusionGemmaForBlockDiffusion" + + +def _load_pipeline(test): + try: + from transformers import AutoProcessor, DiffusionGemmaForBlockDiffusion + except ImportError as e: + test.skipTest(f"transformers without DiffusionGemma: {e}") + try: + model = DiffusionGemmaForBlockDiffusion.from_pretrained(_MODEL_ID, dtype=torch.float32).eval() + processor = AutoProcessor.from_pretrained(_MODEL_ID) + except Exception as e: # noqa: BLE001 - offline / hub errors should skip, not fail + test.skipTest(f"tiny DiffusionGemma checkpoint unavailable: {e}") + pipe = DiffusionGemmaPipeline(model=model, scheduler=BlockRefinementScheduler(), processor=processor) + pipe.set_progress_bar_config(disable=True) + return pipe, model.config.canvas_length + + +class DiffusionGemmaPipelineTest(unittest.TestCase): + def setUp(self): + self.pipe, self.canvas_length = _load_pipeline(self) + self.prompt = "Name a color." 
+ + def test_generate_seq_shape(self): + out = self.pipe( + prompt=self.prompt, + gen_length=self.canvas_length * 2, + num_inference_steps=4, + temperature=0.0, + eos_early_stop=False, + output_type="seq", + ) + self.assertEqual(out.sequences.shape, (1, self.canvas_length * 2)) + self.assertIsNone(out.texts) + + def test_generate_text_and_return_tuple(self): + sequences, texts = self.pipe( + prompt=self.prompt, + gen_length=self.canvas_length, + num_inference_steps=4, + temperature=0.0, + eos_early_stop=False, + output_type="text", + return_dict=False, + ) + self.assertEqual(sequences.shape, (1, self.canvas_length)) + self.assertEqual(len(texts), 1) + + def test_callback_receives_advertised_keys(self): + observed: list[str] = [] + + def callback(pipe, step, timestep, callback_kwargs): + observed.extend(sorted(callback_kwargs.keys())) + return {} + + keys = list(self.pipe._callback_tensor_inputs) + self.pipe( + prompt=self.prompt, + gen_length=self.canvas_length, + num_inference_steps=2, + temperature=0.0, + eos_early_stop=False, + output_type="seq", + callback_on_step_end=callback, + callback_on_step_end_tensor_inputs=keys, + ) + self.assertEqual(set(observed), set(keys)) + + def test_generate_with_image(self): + import numpy as np + from PIL import Image + + image = Image.fromarray((np.random.rand(64, 64, 3) * 255).astype("uint8")) + out = self.pipe( + prompt="What?", + image=image, + gen_length=self.canvas_length, + num_inference_steps=2, + temperature=0.0, + eos_early_stop=False, + output_type="seq", + ) + self.assertEqual(out.sequences.shape, (1, self.canvas_length)) + + def test_schedulers_are_interchangeable(self): + from diffusers import DiscreteDDIMScheduler, EntropyBoundScheduler + + for scheduler in (DiscreteDDIMScheduler(), EntropyBoundScheduler(entropy_bound=0.1)): + self.pipe.scheduler = scheduler + out = self.pipe( + prompt=self.prompt, + gen_length=self.canvas_length, + num_inference_steps=4, + temperature=0.0, + eos_early_stop=False, + 
output_type="seq", + ) + self.assertEqual(out.sequences.shape, (1, self.canvas_length)) + + def test_predictor_corrector_sampling(self): + from diffusers import DiscreteDDIMScheduler + + self.pipe.scheduler = DiscreteDDIMScheduler(corrector_steps=2, corrector_k=2) + out = self.pipe( + prompt=self.prompt, + gen_length=self.canvas_length, + num_inference_steps=4, + temperature=0.0, + eos_early_stop=False, + output_type="seq", + ) + self.assertEqual(out.sequences.shape, (1, self.canvas_length)) + + @require_peft_backend + @require_peft_version_greater("0.18.9") + def test_peft_adapter_api(self): + from peft import LoraConfig + + self.assertEqual(self.pipe.active_adapters(), []) + + # The forwarded API is adapter-type-agnostic; LoRA stands in for any PEFT adapter (DoRA, IA3, ...). + self.pipe.model.add_adapter( + LoraConfig(r=4, lora_alpha=8, lora_dropout=0.0, target_modules="all-linear"), + adapter_name="test", + ) + self.pipe.set_adapter("test") + self.assertIn("test", self.pipe.active_adapters()) + + out = self.pipe( + prompt=self.prompt, + gen_length=self.canvas_length, + num_inference_steps=2, + temperature=0.0, + eos_early_stop=False, + output_type="seq", + ) + self.assertEqual(out.sequences.shape, (1, self.canvas_length)) + + self.pipe.disable_adapters() + self.pipe.enable_adapters() + self.pipe.delete_adapter("test") + self.assertEqual(self.pipe.active_adapters(), []) + + def test_static_cache_matches_dynamic(self): + kwargs = { + "prompt": self.prompt, + "gen_length": self.canvas_length * 2, # two canvases -> exercises the cache extension between blocks + "num_inference_steps": 4, + "temperature": 0.0, + "eos_early_stop": False, + "output_type": "seq", + } + dynamic = self.pipe(generator=torch.Generator().manual_seed(0), **kwargs).sequences + static = self.pipe( + generator=torch.Generator().manual_seed(0), cache_implementation="static", **kwargs + ).sequences + self.assertTrue(torch.equal(dynamic, static)) + + +if __name__ == "__main__": + unittest.main() 
diff --git a/tests/schedulers/test_scheduler_block_refinement.py b/tests/schedulers/test_scheduler_block_refinement.py index 2e5e404e5f9a..673c38890516 100644 --- a/tests/schedulers/test_scheduler_block_refinement.py +++ b/tests/schedulers/test_scheduler_block_refinement.py @@ -466,5 +466,61 @@ def test_negative_temperature_raises(self): ) +class BlockRefinementSchedulerUniformTest(unittest.TestCase): + """Tests for the uniform corruption mode (`mask_token_id=None`), matching DiffusionGemma's block refinement.""" + + def get_scheduler(self, **kwargs): + config = {"block_length": 256, "num_inference_steps": 48, "threshold": 1.0, "editing_threshold": None} + config.update(kwargs) + scheduler = BlockRefinementScheduler(**config) + scheduler.set_timesteps(config["num_inference_steps"], block_length=config["block_length"]) + return scheduler + + def test_cumulative_quota_progression(self): + # threshold=1.0 disables threshold commits, so only the even per-step quota applies: ceil(256/48)=6, then 11. 
+ scheduler = self.get_scheduler() + sample = torch.randint(0, 10000, (1, 256)) + logits = torch.zeros(1, 256, 10000) + out0 = scheduler.step(logits, timestep=0, sample=sample, mask_token_id=None) + self.assertEqual(scheduler._committed.sum().item(), 6) + scheduler.step(logits, timestep=1, sample=out0.prev_sample, mask_token_id=None) + self.assertEqual(scheduler._committed.sum().item(), 11) + + def test_last_step_commits_all(self): + scheduler = self.get_scheduler() + sample = torch.randint(0, 10000, (1, 256)) + logits = torch.zeros(1, 256, 10000) + scheduler.step(logits, timestep=0, sample=sample, mask_token_id=None) + scheduler.step(logits, timestep=47, sample=sample, mask_token_id=None) + self.assertTrue(scheduler._committed.all()) + + def test_threshold_commits_beyond_quota(self): + scheduler = self.get_scheduler(threshold=0.5) + sample = torch.randint(0, 10000, (1, 256)) + logits = torch.zeros(1, 256, 10000) + logits[0, torch.arange(20), 0] = 1e6 # 20 high-confidence positions (token 0) + scheduler.step(logits, timestep=0, sample=sample, mask_token_id=None, temperature=0.0) + # 20 positions exceed the threshold and get committed regardless of the quota + self.assertEqual(scheduler._committed.sum().item(), 20) + + def test_editing_replaces_committed_token(self): + scheduler = self.get_scheduler(threshold=1.0, editing_threshold=0.5) + sample = torch.zeros(1, 256, dtype=torch.long) + scheduler._committed = torch.ones_like(sample, dtype=torch.bool) # pretend all committed + logits = torch.zeros(1, 256, 10000) + logits[0, 0, 1] = 1e6 # confidently predicts token 1 at position 0 (differs from current token 0) + out = scheduler.step(logits, timestep=24, sample=sample, mask_token_id=None, temperature=0.0) + self.assertEqual(out.prev_sample[0, 0].item(), 1) + self.assertTrue((out.prev_sample[0, 1:] == 0).all()) + + def test_reset_on_new_block(self): + scheduler = self.get_scheduler() + sample = torch.randint(0, 10000, (1, 256)) + logits = torch.zeros(1, 256, 10000) + 
scheduler.step(logits, timestep=5, sample=sample, mask_token_id=None) + scheduler.step(logits, timestep=0, sample=sample, mask_token_id=None) # new block resets committed + self.assertEqual(scheduler._committed.sum().item(), 6) + + if __name__ == "__main__": unittest.main() diff --git a/tests/schedulers/test_scheduler_discrete_ddim.py b/tests/schedulers/test_scheduler_discrete_ddim.py new file mode 100644 index 000000000000..ce5ed47ab13b --- /dev/null +++ b/tests/schedulers/test_scheduler_discrete_ddim.py @@ -0,0 +1,106 @@ +import unittest + +import torch + +from diffusers import DiscreteDDIMScheduler + + +class DiscreteDDIMSchedulerTest(unittest.TestCase): + def get_scheduler(self, **kwargs): + config = {"num_inference_steps": 8} + config.update(kwargs) + return DiscreteDDIMScheduler(**config) + + def test_set_timesteps(self): + scheduler = self.get_scheduler() + scheduler.set_timesteps(16) + self.assertEqual(scheduler.num_inference_steps, 16) + self.assertEqual(len(scheduler.timesteps), 16) + self.assertEqual(scheduler.timesteps[0].item(), 0) + self.assertEqual(scheduler.timesteps[-1].item(), 15) + + def test_set_timesteps_invalid(self): + scheduler = self.get_scheduler() + with self.assertRaises(ValueError): + scheduler.set_timesteps(0) + + def test_last_step_commits_predicted_tokens(self): + # On the final step alpha_s = 1, so the posterior deterministically commits the sampled clean tokens. + n = 8 + scheduler = self.get_scheduler(num_inference_steps=n) + scheduler.set_timesteps(n) + sample = torch.randint(0, 100, (2, 16)) + logits = torch.zeros(2, 16, 100) + out = scheduler.step(logits, timestep=n - 1, sample=sample, temperature=0.0) + self.assertTrue(torch.equal(out.prev_sample, out.sampled_tokens)) + + def test_intermediate_step_keeps_agreeing_positions(self): + # Where the prediction agrees with the current token, almost all posterior mass is on the clean route. 
+ n = 8 + scheduler = self.get_scheduler(num_inference_steps=n) + scheduler.set_timesteps(n) + sample = torch.randint(0, 100, (1, 256)) + logits = torch.zeros(1, 256, 100) + # argmax of zero logits is token 0; make the sample already equal token 0 everywhere + sample = torch.zeros_like(sample) + out = scheduler.step(logits, timestep=n // 2, sample=sample, temperature=0.0) + kept = (out.prev_sample == sample).sum().item() + self.assertGreaterEqual(kept, 250) + + def test_step_output_shapes(self): + scheduler = self.get_scheduler() + scheduler.set_timesteps(8) + sample = torch.randint(0, 100, (3, 16)) + logits = torch.randn(3, 16, 100) + out = scheduler.step(logits, timestep=2, sample=sample, temperature=1.0) + self.assertEqual(out.prev_sample.shape, sample.shape) + self.assertEqual(out.sampled_tokens.shape, sample.shape) + self.assertEqual(out.sampled_probs.shape, sample.shape) + + def test_return_tuple(self): + scheduler = self.get_scheduler() + scheduler.set_timesteps(8) + sample = torch.randint(0, 100, (1, 16)) + logits = torch.randn(1, 16, 100) + out = scheduler.step(logits, timestep=2, sample=sample, return_dict=False) + self.assertIsInstance(out, tuple) + self.assertEqual(len(out), 3) + + def test_to_loo_only_shifts_observed_token(self): + # The denoiser->LOO conversion moves only the observed token's logit at each position (eq. 13). 
+ scheduler = self.get_scheduler() + sample = torch.randint(0, 100, (2, 16)) + logits = torch.randn(2, 16, 100) + loo = scheduler._to_loo_logits(logits, sample, alpha=0.4) + diff = loo - logits + moved = diff.abs() > 0 + self.assertTrue(torch.equal(moved.sum(dim=-1), torch.ones(2, 16, dtype=torch.long))) + + def test_step_correct_output_shapes(self): + scheduler = self.get_scheduler(corrector_steps=1, corrector_k=4) + scheduler.set_timesteps(8) + sample = torch.randint(0, 100, (3, 16)) + logits = torch.randn(3, 16, 100) + out = scheduler.step_correct(logits, timestep=2, sample=sample) + self.assertEqual(out.prev_sample.shape, sample.shape) + self.assertEqual(out.prev_sample.dtype, sample.dtype) + + def test_step_correct_resamples_at_most_k(self): + # A corrector sweep holds all but `corrector_k` positions per row fixed. + k = 3 + scheduler = self.get_scheduler(corrector_steps=1, corrector_k=k) + scheduler.set_timesteps(8) + sample = torch.randint(0, 100, (4, 16)) + logits = torch.randn(4, 16, 100) + out = scheduler.step_correct(logits, timestep=2, sample=sample) + changed = (out.prev_sample != sample).sum(dim=-1) + self.assertTrue(torch.all(changed <= k)) + + def test_step_correct_return_tuple(self): + scheduler = self.get_scheduler(corrector_steps=1) + scheduler.set_timesteps(8) + sample = torch.randint(0, 100, (1, 16)) + logits = torch.randn(1, 16, 100) + out = scheduler.step_correct(logits, timestep=2, sample=sample, return_dict=False) + self.assertIsInstance(out, tuple) + self.assertEqual(len(out), 3) diff --git a/tests/schedulers/test_scheduler_entropy_bound.py b/tests/schedulers/test_scheduler_entropy_bound.py new file mode 100644 index 000000000000..c5bd8560d57d --- /dev/null +++ b/tests/schedulers/test_scheduler_entropy_bound.py @@ -0,0 +1,56 @@ +import unittest + +import torch + +from diffusers import EntropyBoundScheduler + + +class EntropyBoundSchedulerTest(unittest.TestCase): + def get_scheduler(self, **kwargs): + config = {"entropy_bound": 0.1, 
"num_inference_steps": 8} + config.update(kwargs) + return EntropyBoundScheduler(**config) + + def test_set_timesteps(self): + scheduler = self.get_scheduler() + scheduler.set_timesteps(16) + self.assertEqual(scheduler.num_inference_steps, 16) + self.assertEqual(len(scheduler.timesteps), 16) + + def test_zero_entropy_positions_accepted(self): + # Positions with a near-one probability have ~zero entropy and must be accepted. + scheduler = self.get_scheduler(entropy_bound=0.1) + sample = torch.randint(0, 10000, (1, 256)) + logits = torch.zeros(1, 256, 10000) + logits[0, :9, 0] = 1e6 # 9 zero-entropy positions + out = scheduler.step(logits, timestep=0, sample=sample, temperature=0.0) + self.assertGreaterEqual(out.accepted_index.sum().item(), 9) + # accepted positions hold the sampled token (token 0 here) + self.assertTrue((out.prev_sample[0, :9] == 0).all()) + + def test_higher_bound_accepts_at_least_as_many(self): + sample = torch.randint(0, 10000, (1, 256)) + logits = torch.zeros(1, 256, 10000) + logits[0, 0, 0] = 1.8e1 + logits[0, 1, 1] = 1.45e1 + logits[0, 2, 2] = 1.45e1 + low = self.get_scheduler(entropy_bound=1e-2).step(logits, 0, sample, temperature=0.0) + high = self.get_scheduler(entropy_bound=1e-1).step(logits, 0, sample, temperature=0.0) + self.assertGreaterEqual(high.accepted_index.sum().item(), low.accepted_index.sum().item()) + + def test_non_accepted_are_renoised(self): + scheduler = self.get_scheduler(entropy_bound=1e-3) + sample = torch.randint(0, 10000, (1, 256)) + logits = torch.zeros(1, 256, 10000) + logits[0, :5, 0] = 1e6 + out = scheduler.step(logits, timestep=0, sample=sample, temperature=0.0) + # the 5 accepted positions hold token 0, the rest are random (not token 0 almost surely) + self.assertTrue((out.prev_sample[0, :5] == 0).all()) + + def test_step_output_shapes(self): + scheduler = self.get_scheduler() + sample = torch.randint(0, 100, (3, 16)) + logits = torch.randn(3, 16, 100) + out = scheduler.step(logits, timestep=0, sample=sample, 
temperature=1.0) + self.assertEqual(out.prev_sample.shape, sample.shape) + self.assertEqual(out.accepted_index.shape, sample.shape)