From ef0b135db190af0630b1e429aad6cea20c1cf9b3 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 14 Jun 2026 19:07:36 +0200 Subject: [PATCH 01/17] Add discrete DDIM and entropy bound schedulers and a uniform mode for block refinement --- src/diffusers/__init__.py | 8 + src/diffusers/schedulers/__init__.py | 4 + .../schedulers/scheduling_block_refinement.py | 57 +++++- .../schedulers/scheduling_discrete_ddim.py | 173 ++++++++++++++++++ .../schedulers/scheduling_entropy_bound.py | 166 +++++++++++++++++ src/diffusers/utils/dummy_pt_objects.py | 60 ++++++ .../test_scheduler_block_refinement.py | 56 ++++++ .../test_scheduler_discrete_ddim.py | 67 +++++++ .../test_scheduler_entropy_bound.py | 56 ++++++ 9 files changed, 641 insertions(+), 6 deletions(-) create mode 100644 src/diffusers/schedulers/scheduling_discrete_ddim.py create mode 100644 src/diffusers/schedulers/scheduling_entropy_bound.py create mode 100644 tests/schedulers/test_scheduler_discrete_ddim.py create mode 100644 tests/schedulers/test_scheduler_entropy_bound.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index da77fa67df52..e5b36731b5ba 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -385,6 +385,10 @@ "AmusedScheduler", "BlockRefinementScheduler", "BlockRefinementSchedulerOutput", + "DiscreteDDIMScheduler", + "DiscreteDDIMSchedulerOutput", + "EntropyBoundScheduler", + "EntropyBoundSchedulerOutput", "CMStochasticIterativeScheduler", "CogVideoXDDIMScheduler", "CogVideoXDPMScheduler", @@ -1240,6 +1244,10 @@ AmusedScheduler, BlockRefinementScheduler, BlockRefinementSchedulerOutput, + DiscreteDDIMScheduler, + DiscreteDDIMSchedulerOutput, + EntropyBoundScheduler, + EntropyBoundSchedulerOutput, CMStochasticIterativeScheduler, CogVideoXDDIMScheduler, CogVideoXDPMScheduler, diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index 447586c6f436..440f0b91ded9 100644 --- a/src/diffusers/schedulers/__init__.py +++ 
b/src/diffusers/schedulers/__init__.py @@ -41,6 +41,8 @@ _import_structure["deprecated"] = ["KarrasVeScheduler", "ScoreSdeVpScheduler"] _import_structure["scheduling_amused"] = ["AmusedScheduler"] _import_structure["scheduling_block_refinement"] = ["BlockRefinementScheduler", "BlockRefinementSchedulerOutput"] + _import_structure["scheduling_discrete_ddim"] = ["DiscreteDDIMScheduler", "DiscreteDDIMSchedulerOutput"] + _import_structure["scheduling_entropy_bound"] = ["EntropyBoundScheduler", "EntropyBoundSchedulerOutput"] _import_structure["scheduling_consistency_decoder"] = ["ConsistencyDecoderScheduler"] _import_structure["scheduling_consistency_models"] = ["CMStochasticIterativeScheduler"] _import_structure["scheduling_ddim"] = ["DDIMScheduler"] @@ -148,6 +150,8 @@ from .deprecated import KarrasVeScheduler, ScoreSdeVpScheduler from .scheduling_amused import AmusedScheduler from .scheduling_block_refinement import BlockRefinementScheduler, BlockRefinementSchedulerOutput + from .scheduling_discrete_ddim import DiscreteDDIMScheduler, DiscreteDDIMSchedulerOutput + from .scheduling_entropy_bound import EntropyBoundScheduler, EntropyBoundSchedulerOutput from .scheduling_consistency_decoder import ConsistencyDecoderScheduler from .scheduling_consistency_models import CMStochasticIterativeScheduler from .scheduling_ddim import DDIMScheduler diff --git a/src/diffusers/schedulers/scheduling_block_refinement.py b/src/diffusers/schedulers/scheduling_block_refinement.py index 3b4d737767ce..2a7c73769755 100644 --- a/src/diffusers/schedulers/scheduling_block_refinement.py +++ b/src/diffusers/schedulers/scheduling_block_refinement.py @@ -74,6 +74,8 @@ def __init__( self.num_inference_steps = num_inference_steps self.timesteps = torch.arange(self.num_inference_steps - 1, -1, -1, dtype=torch.long) self._transfer_schedule: torch.LongTensor | None = None + # committed positions for the uniform corruption mode (no mask token); reset at the start of each block + self._committed: 
torch.BoolTensor | None = None def set_timesteps( self, @@ -92,6 +94,7 @@ def set_timesteps( self._transfer_schedule = self.get_num_transfer_tokens(block_length, self.num_inference_steps).to( device=device if device is not None else "cpu" ) + self._committed = None def get_num_transfer_tokens(self, block_length: int, num_inference_steps: int) -> torch.LongTensor: """Evenly distribute `block_length` token commits across `num_inference_steps` steps.""" @@ -178,7 +181,7 @@ def step( timestep: int | torch.Tensor, sample: torch.LongTensor, *, - mask_token_id: int, + mask_token_id: int | None = None, temperature: float = 0.0, top_p: float | None = None, top_k: int | None = None, @@ -203,9 +206,11 @@ def step( timestep (`int` or `torch.Tensor`): Current step index within the block's refinement schedule. sample (`torch.LongTensor` of shape `(batch_size, block_length)`): - Current block token IDs (contains mask tokens for uncommitted positions). - mask_token_id (`int`): - Token ID used for masked positions. + Current block token IDs (contains mask tokens for uncommitted positions in the mask-based mode). + mask_token_id (`int`, *optional*): + Token ID used for masked positions. When `None`, the scheduler runs in uniform corruption mode: it + tracks committed positions internally (resetting at `timestep == 0`) and renoises the uncommitted + ones with uniformly random tokens, matching DiffusionGemma's block refinement sampler. temperature (`float`): Sampling temperature. 
top_p (`float`, *optional*): @@ -247,14 +252,54 @@ def step( ) batch_size, block_length = sample.shape - active_block = sample == mask_token_id - masks_remaining = active_block.any() if isinstance(timestep, torch.Tensor): step_index = int(timestep.item()) else: step_index = int(timestep) + # --- Uniform corruption mode (DiffusionGemma): no mask token, committed positions tracked as state --- + if mask_token_id is None: + if step_index == 0 or self._committed is None or self._committed.shape != sample.shape: + self._committed = torch.zeros_like(sample, dtype=torch.bool) + committed = self._committed + confidence = sampled_probs.to(dtype=torch.float32) + + # Cumulative quota: evenly distribute the block across the steps, commit what is still owed + steps_done = step_index + 1 + target = (steps_done * block_length + self.num_inference_steps - 1) // self.num_inference_steps + needed = (target - committed.sum(dim=-1)).clamp(min=0) + + masked_confidence = confidence.masked_fill(committed, float("-inf")) + ranks = masked_confidence.argsort(dim=-1, descending=True).argsort(dim=-1) + transfer_index = ~committed & ((ranks < needed[:, None]) | (confidence > threshold)) + + editing_transfer_index = torch.zeros_like(transfer_index) + if editing_threshold is not None: + editing_transfer_index = ( + committed & (sampled_tokens != sample) & (confidence > float(editing_threshold)) + ) + + prev_sample = torch.where(transfer_index | editing_transfer_index, sampled_tokens, sample) + self._committed = committed | transfer_index + random_tokens = torch.randint( + low=0, high=model_output.shape[-1], size=sample.shape, device=sample.device, generator=generator + ) + prev_sample = torch.where(self._committed, prev_sample, random_tokens) + + if not return_dict: + return prev_sample, transfer_index, editing_transfer_index, sampled_tokens, sampled_probs + return BlockRefinementSchedulerOutput( + prev_sample=prev_sample, + transfer_index=transfer_index, + 
editing_transfer_index=editing_transfer_index, + sampled_tokens=sampled_tokens, + sampled_probs=sampled_probs, + ) + + active_block = sample == mask_token_id + masks_remaining = active_block.any() + # --- Mask-filling transfer --- transfer_index = torch.zeros_like(sampled_tokens, dtype=torch.bool) if masks_remaining and self._transfer_schedule is not None: diff --git a/src/diffusers/schedulers/scheduling_discrete_ddim.py b/src/diffusers/schedulers/scheduling_discrete_ddim.py new file mode 100644 index 000000000000..51a890da28a7 --- /dev/null +++ b/src/diffusers/schedulers/scheduling_discrete_ddim.py @@ -0,0 +1,173 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from dataclasses import dataclass + +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput +from .scheduling_utils import SchedulerMixin + + +@dataclass +class DiscreteDDIMSchedulerOutput(BaseOutput): + """ + Output class for the discrete DDIM scheduler. + + Args: + prev_sample (`torch.LongTensor` of shape `(batch_size, block_length)`): + Updated block tokens after the current denoising step. + sampled_tokens (`torch.LongTensor` of shape `(batch_size, block_length)`): + Token IDs sampled from the model logits, i.e. the predicted clean tokens `x0`. 
+ sampled_probs (`torch.Tensor` of shape `(batch_size, block_length)`): + Probabilities of the sampled tokens. + """ + + prev_sample: torch.LongTensor + sampled_tokens: torch.LongTensor + sampled_probs: torch.Tensor + + +class DiscreteDDIMScheduler(SchedulerMixin, ConfigMixin): + """ + Discrete DDIM scheduler for the uniform corruption process, following "Structured Denoising Diffusion Models in + Discrete State-Spaces" (D3PM, https://huggingface.co/papers/2107.03006). + + On the linear schedule the survival probability of a clean token at time `t` is `alpha(t) = 1 - t`. One denoising + step from time `t` to `s < t` samples every block position from the exact posterior `q(x_s | x_t, x0)`, which for + the uniform kernel decomposes into three routes: jump to the predicted clean token `x0`, stay on the current + token, or jump to a uniformly random token. Unlike masked diffusion, there is no mask token; uncommitted positions + carry random tokens. + + Args: + num_inference_steps (`int`, defaults to 32): + The number of denoising steps, defining the linear time grid the posterior is evaluated on. 
+ """ + + order = 1 + + @register_to_config + def __init__(self, num_inference_steps: int = 32): + self.num_inference_steps = num_inference_steps + self.timesteps = torch.arange(num_inference_steps, dtype=torch.long) + + def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None: + if num_inference_steps <= 0: + raise ValueError(f"`num_inference_steps` must be > 0, got {num_inference_steps}.") + self.num_inference_steps = num_inference_steps + self.timesteps = torch.arange(num_inference_steps, device=device, dtype=torch.long) + + @staticmethod + def _sample_from_logits( + logits: torch.Tensor, + *, + temperature: float, + generator: torch.Generator | None, + ) -> tuple[torch.LongTensor, torch.Tensor]: + """Sample one token per position with optional temperature, returning tokens and their probabilities.""" + if temperature < 0: + raise ValueError(f"`temperature` must be >= 0, got {temperature}.") + + vocab_size = logits.shape[-1] + flat_logits = logits.reshape(-1, vocab_size) + probs = torch.softmax(flat_logits.float(), dim=-1) + + if temperature == 0.0: + token = flat_logits.argmax(dim=-1, keepdim=True) + else: + scaled_probs = torch.softmax(flat_logits.float() / temperature, dim=-1) + token = torch.multinomial(scaled_probs, num_samples=1, generator=generator) + + token_prob = torch.gather(probs, -1, token) + return token.view(*logits.shape[:-1]), token_prob.view(*logits.shape[:-1]) + + def step( + self, + model_output: torch.Tensor, + timestep: int | torch.Tensor, + sample: torch.LongTensor, + *, + temperature: float = 0.0, + generator: torch.Generator | None = None, + return_dict: bool = True, + ) -> DiscreteDDIMSchedulerOutput | tuple[torch.LongTensor, torch.LongTensor, torch.Tensor]: + """ + Sample the next block from the posterior `q(x_s | x_t, x0)` of the uniform corruption process. 
+ + With `a = alpha_t / alpha_s` (survival probability from `s` to `t`) and `b = alpha_s`, the posterior mass of + each route is + + clean: `b * (1 - a) / K + a * b * 1[x_t = x0]`, stay: `a * (1 - b) / K`, noise: `(1 - a) * (1 - b) / K`, + + so the last step (`b = 1`) deterministically commits the predicted clean tokens. + + Args: + model_output (`torch.Tensor` of shape `(batch_size, block_length, vocab_size)`): + Raw logits from the model for the current block. + timestep (`int` or `torch.Tensor`): + Current step index within the denoising schedule, in `[0, num_inference_steps - 1]`. + sample (`torch.LongTensor` of shape `(batch_size, block_length)`): + Current block token IDs `x_t`. + temperature (`float`): + Sampling temperature applied to the logits when drawing `x0`. + generator (`torch.Generator`, *optional*): + RNG for sampling. + return_dict (`bool`): + Whether to return a [`DiscreteDDIMSchedulerOutput`] or a plain tuple. + """ + if isinstance(timestep, torch.Tensor): + step_index = int(timestep.item()) + else: + step_index = int(timestep) + + sampled_tokens, sampled_probs = self._sample_from_logits( + model_output, temperature=temperature, generator=generator + ) + + vocab_size = model_output.shape[-1] + num_steps = self.num_inference_steps + # `step_index` counts up from 0 to `num_inference_steps - 1`: alpha(t) = 1 - t increases towards the clean end, + # with alpha_s = 1 on the final step so the predicted clean tokens are committed deterministically. 
+ alpha_t = step_index / num_steps + alpha_s = (step_index + 1) / num_steps + survival = alpha_t / alpha_s + + same = (sample == sampled_tokens).float() + clean_mass = alpha_s * (1 - survival) / vocab_size + survival * alpha_s * same + stay_mass = survival * (1 - alpha_s) / vocab_size * torch.ones_like(same) + noise_mass = (1 - survival) * (1 - alpha_s) / vocab_size * torch.ones_like(same) + + route_probs = torch.stack([clean_mass, stay_mass, noise_mass], dim=-1) + route_probs = route_probs / route_probs.sum(dim=-1, keepdim=True) + routes = torch.multinomial(route_probs.view(-1, 3), num_samples=1, generator=generator).view_as(sample) + + random_tokens = torch.randint( + low=0, high=vocab_size, size=sample.shape, device=sample.device, generator=generator + ) + prev_sample = torch.where(routes == 0, sampled_tokens, sample) + prev_sample = torch.where(routes == 2, random_tokens, prev_sample) + + if not return_dict: + return prev_sample, sampled_tokens, sampled_probs + return DiscreteDDIMSchedulerOutput( + prev_sample=prev_sample, + sampled_tokens=sampled_tokens, + sampled_probs=sampled_probs, + ) + + +__all__ = ["DiscreteDDIMScheduler", "DiscreteDDIMSchedulerOutput"] diff --git a/src/diffusers/schedulers/scheduling_entropy_bound.py b/src/diffusers/schedulers/scheduling_entropy_bound.py new file mode 100644 index 000000000000..fe54eabb403d --- /dev/null +++ b/src/diffusers/schedulers/scheduling_entropy_bound.py @@ -0,0 +1,166 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from dataclasses import dataclass + +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput +from .scheduling_utils import SchedulerMixin + + +@dataclass +class EntropyBoundSchedulerOutput(BaseOutput): + """ + Output class for the entropy bound scheduler. + + Args: + prev_sample (`torch.LongTensor` of shape `(batch_size, block_length)`): + Updated block tokens after the current denoising step. + accepted_index (`torch.BoolTensor` of shape `(batch_size, block_length)`): + Boolean mask of the positions accepted (committed) in this step. + sampled_tokens (`torch.LongTensor` of shape `(batch_size, block_length)`): + Token IDs sampled from the model logits. + sampled_probs (`torch.Tensor` of shape `(batch_size, block_length)`): + Probabilities of the sampled tokens. + """ + + prev_sample: torch.LongTensor + accepted_index: torch.BoolTensor + sampled_tokens: torch.LongTensor + sampled_probs: torch.Tensor + + +class EntropyBoundScheduler(SchedulerMixin, ConfigMixin): + """ + Entropy bound scheduler for the uniform corruption process. + + At each step the scheduler samples a candidate token per position and accepts the `k` lowest-entropy positions + such that `sum_i^k entropy_i - max(entropy_1, ..., entropy_k) <= entropy_bound`. The left-hand side upper-bounds + the joint mutual information between the accepted tokens, so they are approximately independent. Accepted positions + keep their sampled token; the rest are renoised with uniformly random tokens (there is no mask token). + + Proposed in "Beyond Next-Token Prediction" (https://huggingface.co/papers/2505.24857). + + Args: + entropy_bound (`float`, defaults to 0.1): + The maximum tolerated joint entropy of the accepted tokens. Larger values accept more tokens per step. 
+ num_inference_steps (`int`, defaults to 32): + The maximum number of denoising steps. + """ + + order = 1 + + @register_to_config + def __init__(self, entropy_bound: float = 0.1, num_inference_steps: int = 32): + self.num_inference_steps = num_inference_steps + self.timesteps = torch.arange(num_inference_steps, dtype=torch.long) + + def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None: + if num_inference_steps <= 0: + raise ValueError(f"`num_inference_steps` must be > 0, got {num_inference_steps}.") + self.num_inference_steps = num_inference_steps + self.timesteps = torch.arange(num_inference_steps, device=device, dtype=torch.long) + + @staticmethod + def _sample_from_logits( + logits: torch.Tensor, + *, + temperature: float, + generator: torch.Generator | None, + ) -> tuple[torch.LongTensor, torch.Tensor]: + """Sample one token per position with optional temperature, returning tokens and their probabilities.""" + if temperature < 0: + raise ValueError(f"`temperature` must be >= 0, got {temperature}.") + + vocab_size = logits.shape[-1] + flat_logits = logits.reshape(-1, vocab_size) + probs = torch.softmax(flat_logits.float(), dim=-1) + + if temperature == 0.0: + token = flat_logits.argmax(dim=-1, keepdim=True) + else: + scaled_probs = torch.softmax(flat_logits.float() / temperature, dim=-1) + token = torch.multinomial(scaled_probs, num_samples=1, generator=generator) + + token_prob = torch.gather(probs, -1, token) + return token.view(*logits.shape[:-1]), token_prob.view(*logits.shape[:-1]) + + def step( + self, + model_output: torch.Tensor, + timestep: int | torch.Tensor, + sample: torch.LongTensor, + *, + entropy_bound: float | None = None, + temperature: float = 1.0, + generator: torch.Generator | None = None, + return_dict: bool = True, + ) -> EntropyBoundSchedulerOutput | tuple[torch.LongTensor, torch.BoolTensor, torch.LongTensor, torch.Tensor]: + """ + Accept the lowest-entropy positions under the entropy bound 
and renoise the rest. + + Args: + model_output (`torch.Tensor` of shape `(batch_size, block_length, vocab_size)`): + Raw logits from the model for the current block. + timestep (`int` or `torch.Tensor`): + Current step index within the denoising schedule. Unused; kept for API consistency. + sample (`torch.LongTensor` of shape `(batch_size, block_length)`): + Current block token IDs. + entropy_bound (`float`, *optional*): + Overrides the configured entropy bound for this step. + temperature (`float`): + Sampling temperature applied to the logits when drawing the candidate tokens. + generator (`torch.Generator`, *optional*): + RNG for sampling. + return_dict (`bool`): + Whether to return an [`EntropyBoundSchedulerOutput`] or a plain tuple. + """ + if entropy_bound is None: + entropy_bound = float(self.config.entropy_bound) + + sampled_tokens, sampled_probs = self._sample_from_logits( + model_output, temperature=temperature, generator=generator + ) + + token_entropy = torch.distributions.Categorical(logits=model_output).entropy() # (batch, block_length) + sorted_token_entropy, sorted_indices = torch.sort(token_entropy, dim=-1, descending=False) + cumulative_entropy = torch.cumsum(sorted_token_entropy, dim=-1) + + # `sorted_token_entropy` is the running maximum entropy (ascending order), so the left-hand side bounds the + # joint mutual information of the accepted tokens. 
+ sorted_accepted = cumulative_entropy - sorted_token_entropy <= entropy_bound + accepted_index = torch.scatter( + input=torch.zeros_like(sorted_accepted), dim=-1, index=sorted_indices, src=sorted_accepted + ) + + random_tokens = torch.randint( + low=0, high=model_output.shape[-1], size=sample.shape, device=sample.device, generator=generator + ) + prev_sample = torch.where(accepted_index, sampled_tokens, random_tokens) + + if not return_dict: + return prev_sample, accepted_index, sampled_tokens, sampled_probs + return EntropyBoundSchedulerOutput( + prev_sample=prev_sample, + accepted_index=accepted_index, + sampled_tokens=sampled_tokens, + sampled_probs=sampled_probs, + ) + + +__all__ = ["EntropyBoundScheduler", "EntropyBoundSchedulerOutput"] diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 8439a2b93371..823123988830 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -2882,6 +2882,66 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class DiscreteDDIMScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class DiscreteDDIMSchedulerOutput(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class EntropyBoundScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, 
**kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class EntropyBoundSchedulerOutput(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class CMStochasticIterativeScheduler(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/schedulers/test_scheduler_block_refinement.py b/tests/schedulers/test_scheduler_block_refinement.py index 2e5e404e5f9a..673c38890516 100644 --- a/tests/schedulers/test_scheduler_block_refinement.py +++ b/tests/schedulers/test_scheduler_block_refinement.py @@ -466,5 +466,61 @@ def test_negative_temperature_raises(self): ) +class BlockRefinementSchedulerUniformTest(unittest.TestCase): + """Tests for the uniform corruption mode (`mask_token_id=None`), matching DiffusionGemma's block refinement.""" + + def get_scheduler(self, **kwargs): + config = {"block_length": 256, "num_inference_steps": 48, "threshold": 1.0, "editing_threshold": None} + config.update(kwargs) + scheduler = BlockRefinementScheduler(**config) + scheduler.set_timesteps(config["num_inference_steps"], block_length=config["block_length"]) + return scheduler + + def test_cumulative_quota_progression(self): + # threshold=1.0 disables threshold commits, so only the even per-step quota applies: ceil(256/48)=6, then 11. 
+ scheduler = self.get_scheduler() + sample = torch.randint(0, 10000, (1, 256)) + logits = torch.zeros(1, 256, 10000) + out0 = scheduler.step(logits, timestep=0, sample=sample, mask_token_id=None) + self.assertEqual(scheduler._committed.sum().item(), 6) + scheduler.step(logits, timestep=1, sample=out0.prev_sample, mask_token_id=None) + self.assertEqual(scheduler._committed.sum().item(), 11) + + def test_last_step_commits_all(self): + scheduler = self.get_scheduler() + sample = torch.randint(0, 10000, (1, 256)) + logits = torch.zeros(1, 256, 10000) + scheduler.step(logits, timestep=0, sample=sample, mask_token_id=None) + scheduler.step(logits, timestep=47, sample=sample, mask_token_id=None) + self.assertTrue(scheduler._committed.all()) + + def test_threshold_commits_beyond_quota(self): + scheduler = self.get_scheduler(threshold=0.5) + sample = torch.randint(0, 10000, (1, 256)) + logits = torch.zeros(1, 256, 10000) + logits[0, torch.arange(20), 0] = 1e6 # 20 high-confidence positions (token 0) + scheduler.step(logits, timestep=0, sample=sample, mask_token_id=None, temperature=0.0) + # 20 positions exceed the threshold and get committed regardless of the quota + self.assertEqual(scheduler._committed.sum().item(), 20) + + def test_editing_replaces_committed_token(self): + scheduler = self.get_scheduler(threshold=1.0, editing_threshold=0.5) + sample = torch.zeros(1, 256, dtype=torch.long) + scheduler._committed = torch.ones_like(sample, dtype=torch.bool) # pretend all committed + logits = torch.zeros(1, 256, 10000) + logits[0, 0, 1] = 1e6 # confidently predicts token 1 at position 0 (differs from current token 0) + out = scheduler.step(logits, timestep=24, sample=sample, mask_token_id=None, temperature=0.0) + self.assertEqual(out.prev_sample[0, 0].item(), 1) + self.assertTrue((out.prev_sample[0, 1:] == 0).all()) + + def test_reset_on_new_block(self): + scheduler = self.get_scheduler() + sample = torch.randint(0, 10000, (1, 256)) + logits = torch.zeros(1, 256, 10000) + 
scheduler.step(logits, timestep=5, sample=sample, mask_token_id=None) + scheduler.step(logits, timestep=0, sample=sample, mask_token_id=None) # new block resets committed + self.assertEqual(scheduler._committed.sum().item(), 6) + + if __name__ == "__main__": unittest.main() diff --git a/tests/schedulers/test_scheduler_discrete_ddim.py b/tests/schedulers/test_scheduler_discrete_ddim.py new file mode 100644 index 000000000000..8073fbb227b3 --- /dev/null +++ b/tests/schedulers/test_scheduler_discrete_ddim.py @@ -0,0 +1,67 @@ +import unittest + +import torch + +from diffusers import DiscreteDDIMScheduler + + +class DiscreteDDIMSchedulerTest(unittest.TestCase): + def get_scheduler(self, **kwargs): + config = {"num_inference_steps": 8} + config.update(kwargs) + return DiscreteDDIMScheduler(**config) + + def test_set_timesteps(self): + scheduler = self.get_scheduler() + scheduler.set_timesteps(16) + self.assertEqual(scheduler.num_inference_steps, 16) + self.assertEqual(len(scheduler.timesteps), 16) + self.assertEqual(scheduler.timesteps[0].item(), 0) + self.assertEqual(scheduler.timesteps[-1].item(), 15) + + def test_set_timesteps_invalid(self): + scheduler = self.get_scheduler() + with self.assertRaises(ValueError): + scheduler.set_timesteps(0) + + def test_last_step_commits_predicted_tokens(self): + # On the final step alpha_s = 1, so the posterior deterministically commits the sampled clean tokens. + n = 8 + scheduler = self.get_scheduler(num_inference_steps=n) + scheduler.set_timesteps(n) + sample = torch.randint(0, 100, (2, 16)) + logits = torch.zeros(2, 16, 100) + out = scheduler.step(logits, timestep=n - 1, sample=sample, temperature=0.0) + self.assertTrue(torch.equal(out.prev_sample, out.sampled_tokens)) + + def test_intermediate_step_keeps_agreeing_positions(self): + # Where the prediction agrees with the current token, almost all posterior mass is on the clean route. 
+ n = 8 + scheduler = self.get_scheduler(num_inference_steps=n) + scheduler.set_timesteps(n) + sample = torch.randint(0, 100, (1, 256)) + logits = torch.zeros(1, 256, 100) + # argmax of zero logits is token 0; make the sample already equal token 0 everywhere + sample = torch.zeros_like(sample) + out = scheduler.step(logits, timestep=n // 2, sample=sample, temperature=0.0) + kept = (out.prev_sample == sample).sum().item() + self.assertGreaterEqual(kept, 250) + + def test_step_output_shapes(self): + scheduler = self.get_scheduler() + scheduler.set_timesteps(8) + sample = torch.randint(0, 100, (3, 16)) + logits = torch.randn(3, 16, 100) + out = scheduler.step(logits, timestep=2, sample=sample, temperature=1.0) + self.assertEqual(out.prev_sample.shape, sample.shape) + self.assertEqual(out.sampled_tokens.shape, sample.shape) + self.assertEqual(out.sampled_probs.shape, sample.shape) + + def test_return_tuple(self): + scheduler = self.get_scheduler() + scheduler.set_timesteps(8) + sample = torch.randint(0, 100, (1, 16)) + logits = torch.randn(1, 16, 100) + out = scheduler.step(logits, timestep=2, sample=sample, return_dict=False) + self.assertIsInstance(out, tuple) + self.assertEqual(len(out), 3) diff --git a/tests/schedulers/test_scheduler_entropy_bound.py b/tests/schedulers/test_scheduler_entropy_bound.py new file mode 100644 index 000000000000..c5bd8560d57d --- /dev/null +++ b/tests/schedulers/test_scheduler_entropy_bound.py @@ -0,0 +1,56 @@ +import unittest + +import torch + +from diffusers import EntropyBoundScheduler + + +class EntropyBoundSchedulerTest(unittest.TestCase): + def get_scheduler(self, **kwargs): + config = {"entropy_bound": 0.1, "num_inference_steps": 8} + config.update(kwargs) + return EntropyBoundScheduler(**config) + + def test_set_timesteps(self): + scheduler = self.get_scheduler() + scheduler.set_timesteps(16) + self.assertEqual(scheduler.num_inference_steps, 16) + self.assertEqual(len(scheduler.timesteps), 16) + + def 
test_zero_entropy_positions_accepted(self): + # Positions with a near-one probability have ~zero entropy and must be accepted. + scheduler = self.get_scheduler(entropy_bound=0.1) + sample = torch.randint(0, 10000, (1, 256)) + logits = torch.zeros(1, 256, 10000) + logits[0, :9, 0] = 1e6 # 9 zero-entropy positions + out = scheduler.step(logits, timestep=0, sample=sample, temperature=0.0) + self.assertGreaterEqual(out.accepted_index.sum().item(), 9) + # accepted positions hold the sampled token (token 0 here) + self.assertTrue((out.prev_sample[0, :9] == 0).all()) + + def test_higher_bound_accepts_at_least_as_many(self): + sample = torch.randint(0, 10000, (1, 256)) + logits = torch.zeros(1, 256, 10000) + logits[0, 0, 0] = 1.8e1 + logits[0, 1, 1] = 1.45e1 + logits[0, 2, 2] = 1.45e1 + low = self.get_scheduler(entropy_bound=1e-2).step(logits, 0, sample, temperature=0.0) + high = self.get_scheduler(entropy_bound=1e-1).step(logits, 0, sample, temperature=0.0) + self.assertGreaterEqual(high.accepted_index.sum().item(), low.accepted_index.sum().item()) + + def test_non_accepted_are_renoised(self): + scheduler = self.get_scheduler(entropy_bound=1e-3) + sample = torch.randint(0, 10000, (1, 256)) + logits = torch.zeros(1, 256, 10000) + logits[0, :5, 0] = 1e6 + out = scheduler.step(logits, timestep=0, sample=sample, temperature=0.0) + # the 5 accepted positions hold token 0, the rest are random (not token 0 almost surely) + self.assertTrue((out.prev_sample[0, :5] == 0).all()) + + def test_step_output_shapes(self): + scheduler = self.get_scheduler() + sample = torch.randint(0, 100, (3, 16)) + logits = torch.randn(3, 16, 100) + out = scheduler.step(logits, timestep=0, sample=sample, temperature=1.0) + self.assertEqual(out.prev_sample.shape, sample.shape) + self.assertEqual(out.accepted_index.shape, sample.shape) From 6168e6d66e4b7e99ecff4eadea44e3351541d30e Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 18 Jun 2026 13:15:43 +0200 Subject: [PATCH 02/17] Add DiffusionGemma 
block-diffusion pipeline --- src/diffusers/__init__.py | 4 + src/diffusers/pipelines/__init__.py | 2 + .../pipelines/diffusion_gemma/__init__.py | 47 +++ .../pipeline_diffusion_gemma.py | 366 ++++++++++++++++++ .../dummy_torch_and_transformers_objects.py | 30 ++ 5 files changed, 449 insertions(+) create mode 100644 src/diffusers/pipelines/diffusion_gemma/__init__.py create mode 100644 src/diffusers/pipelines/diffusion_gemma/pipeline_diffusion_gemma.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index e5b36731b5ba..a9a01d1f05e7 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -576,6 +576,8 @@ "CosmosTextToWorldPipeline", "CosmosVideoToWorldPipeline", "CycleDiffusionPipeline", + "DiffusionGemmaPipeline", + "DiffusionGemmaPipelineOutput", "DreamLiteMobilePipeline", "DreamLitePipeline", "DreamLitePipelineOutput", @@ -1414,6 +1416,8 @@ CosmosTextToWorldPipeline, CosmosVideoToWorldPipeline, CycleDiffusionPipeline, + DiffusionGemmaPipeline, + DiffusionGemmaPipelineOutput, DreamLiteMobilePipeline, DreamLitePipeline, DreamLitePipelineOutput, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index caec1aee30e7..c9536eca334d 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -273,6 +273,7 @@ "IFPipeline", "IFSuperResolutionPipeline", ] + _import_structure["diffusion_gemma"] = ["DiffusionGemmaPipeline", "DiffusionGemmaPipelineOutput"] _import_structure["dreamlite"] = ["DreamLitePipeline", "DreamLiteMobilePipeline", "DreamLitePipelineOutput"] _import_structure["easyanimate"] = [ "EasyAnimatePipeline", @@ -716,6 +717,7 @@ WuerstchenDecoderPipeline, WuerstchenPriorPipeline, ) + from .diffusion_gemma import DiffusionGemmaPipeline, DiffusionGemmaPipelineOutput from .dreamlite import ( DreamLiteMobilePipeline, DreamLitePipeline, diff --git a/src/diffusers/pipelines/diffusion_gemma/__init__.py b/src/diffusers/pipelines/diffusion_gemma/__init__.py new file 
mode 100644 index 000000000000..282efe37a797 --- /dev/null +++ b/src/diffusers/pipelines/diffusion_gemma/__init__.py @@ -0,0 +1,47 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_diffusion_gemma"] = ["DiffusionGemmaPipeline", "DiffusionGemmaPipelineOutput"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .pipeline_diffusion_gemma import DiffusionGemmaPipeline, DiffusionGemmaPipelineOutput +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/diffusion_gemma/pipeline_diffusion_gemma.py b/src/diffusers/pipelines/diffusion_gemma/pipeline_diffusion_gemma.py new file mode 100644 index 000000000000..e4f97f03c6ff --- /dev/null +++ b/src/diffusers/pipelines/diffusion_gemma/pipeline_diffusion_gemma.py @@ -0,0 +1,366 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Callable + +import torch +import torch.nn.functional as F +from tqdm.auto import tqdm + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...schedulers import BlockRefinementScheduler +from ...utils import BaseOutput, logging, replace_example_docstring +from ..pipeline_utils import DiffusionPipeline + + +logger = logging.get_logger(__name__) + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from transformers import AutoProcessor, DiffusionGemmaForBlockDiffusion + >>> from diffusers import BlockRefinementScheduler, DiffusionGemmaPipeline + + >>> model_id = "google/diffusiongemma-26B-A4B-it" + >>> model = DiffusionGemmaForBlockDiffusion.from_pretrained(model_id, dtype=torch.bfloat16, device_map="auto") + >>> processor = AutoProcessor.from_pretrained(model_id) + >>> scheduler = BlockRefinementScheduler() + + >>> pipe = DiffusionGemmaPipeline(model=model, scheduler=scheduler, processor=processor) + >>> output = pipe(prompt="Why is the sky blue?", gen_length=256) + >>> print(output.texts[0]) + ``` +""" + + +@dataclass +class DiffusionGemmaPipelineOutput(BaseOutput): + sequences: torch.LongTensor + texts: list[str] | None = None + + +class DiffusionGemmaPipeline(DiffusionPipeline): + r""" + Pipeline for DiffusionGemma block-diffusion text generation. 
+ + DiffusionGemma is a block-diffusion encoder-decoder model: a causal encoder reads the clean prompt (and any + previously generated blocks) into a KV cache, and a bidirectional decoder denoises a fixed-size "canvas" of + `canvas_length` tokens by cross-attending to that cache. Generation alternates an outer autoregressive loop over + canvases with an inner denoising loop, where each step samples candidate tokens, commits the most confident ones + via [`BlockRefinementScheduler`] (uniform corruption mode, `mask_token_id=None`), and renoises the rest. + + The model is expected to be a `DiffusionGemmaForBlockDiffusion` instance exposing `forward(input_ids, + decoder_input_ids=..., self_conditioning_logits=..., ...)` and returning logits of shape `[batch, canvas_length, + vocab_size]` over the canvas. + """ + + model: Any + scheduler: BlockRefinementScheduler + processor: Any + + _callback_tensor_inputs = ["canvas", "logits"] + + def __init__( + self, + model: Any, + scheduler: BlockRefinementScheduler, + processor: Any | None = None, + ): + super().__init__() + self.register_modules(model=model, scheduler=scheduler, processor=processor) + text_config = model.config.get_text_config() + self.canvas_length = model.config.canvas_length + self.vocab_size = text_config.vocab_size + tokenizer = getattr(processor, "tokenizer", processor) + self.eos_token_id = getattr(tokenizer, "eos_token_id", None) if tokenizer is not None else None + + @property + def num_timesteps(self): + return self._num_timesteps + + # --- Prompt encoding --- + + def _prepare_input_ids( + self, + *, + prompt: str | list[str] | None, + messages: list[dict[str, str]] | None, + input_ids: torch.LongTensor | None, + attention_mask: torch.LongTensor | None, + add_generation_prompt: bool, + ) -> tuple[torch.LongTensor, torch.LongTensor]: + """Convert prompt/messages/input_ids to `(input_ids, attention_mask)` tensors of shape `[batch, seq]`.""" + if input_ids is not None: + if input_ids.ndim == 1: + 
input_ids = input_ids.unsqueeze(0) + if attention_mask is None: + attention_mask = torch.ones_like(input_ids, dtype=torch.long) + elif attention_mask.ndim == 1: + attention_mask = attention_mask.unsqueeze(0) + return input_ids, attention_mask.to(dtype=torch.long) + + if self.processor is None: + raise ValueError("`processor` is required when `input_ids` is not provided.") + + if messages is None: + if isinstance(prompt, list): + messages = [[{"role": "user", "content": p}] for p in prompt] + else: + messages = [{"role": "user", "content": prompt}] + + encoded = self.processor.apply_chat_template( + messages, + add_generation_prompt=add_generation_prompt, + tokenize=True, + return_tensors="pt", + return_dict=True, + ) + ids = encoded["input_ids"] + mask = encoded.get("attention_mask") + if mask is None: + mask = torch.ones_like(ids, dtype=torch.long) + return ids, mask.to(dtype=torch.long) + + def check_inputs( + self, + prompt: str | list[str] | None, + messages: list[dict[str, str]] | None, + input_ids: torch.LongTensor | None, + gen_length: int, + num_inference_steps: int, + output_type: str, + callback_on_step_end_tensor_inputs: list[str] | None, + ): + if prompt is None and messages is None and input_ids is None: + raise ValueError("Provide one of `prompt`, `messages`, or `input_ids`.") + if prompt is not None and messages is not None: + raise ValueError("Provide either `prompt` or `messages`, not both.") + if (prompt is not None or messages is not None) and input_ids is None and self.processor is None: + raise ValueError("`processor` is required when `input_ids` is not provided.") + if gen_length <= 0: + raise ValueError(f"`gen_length` must be > 0, got {gen_length}.") + if num_inference_steps <= 0: + raise ValueError(f"`num_inference_steps` must be > 0, got {num_inference_steps}.") + if output_type not in {"seq", "text"}: + raise ValueError(f"`output_type` must be 'seq' or 'text', got {output_type!r}.") + if callback_on_step_end_tensor_inputs is not None and 
not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found " + f"{[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] | None = None, + messages: list[dict[str, str]] | None = None, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.LongTensor | None = None, + add_generation_prompt: bool = True, + gen_length: int = 256, + num_inference_steps: int = 32, + temperature: float = 0.0, + top_p: float | None = None, + top_k: int | None = None, + threshold: float | None = None, + editing_threshold: float | None = None, + eos_early_stop: bool = True, + eos_token_id: int | None = None, + generator: torch.Generator | None = None, + output_type: str = "text", + return_dict: bool = True, + callback_on_step_end: Callable[[Any, int, int, dict], dict] + | PipelineCallback + | MultiPipelineCallbacks + | None = None, + callback_on_step_end_tensor_inputs: list[str] | None = None, + ) -> DiffusionGemmaPipelineOutput | tuple[torch.LongTensor, list[str] | None]: + """ + Generate text with block diffusion. + + Args: + prompt (`str` or `List[str]`, *optional*): + Prompt text, wrapped in a chat template and tokenized by the processor. + messages (`List[Dict[str, str]]`, *optional*): + Chat messages to encode (e.g. `[{"role": "user", "content": "Hello"}]`). Mutually exclusive with + `prompt`. Requires a processor with `apply_chat_template`. + input_ids (`torch.LongTensor`, *optional*): + Pre-tokenized prompt IDs. Takes precedence over `prompt` and `messages`. + attention_mask (`torch.LongTensor`, *optional*): + Per-token mask matching `input_ids`. Only used when `input_ids` is provided.
+ add_generation_prompt (`bool`, defaults to `True`): + Whether to add the generation prompt when applying the chat template. + gen_length (`int`, defaults to `256`): + Number of tokens to generate, rounded up to a multiple of the model's `canvas_length`. + num_inference_steps (`int`, defaults to `32`): + Number of denoising steps per canvas. + temperature (`float`, defaults to `0.0`): + Sampling temperature. `0.0` is greedy. + top_p (`float`, *optional*): + Nucleus sampling cutoff. + top_k (`int`, *optional*): + Top-k sampling cutoff. + threshold (`float`, *optional*): + Confidence threshold for committing tokens. Defaults to the scheduler's configured value. + editing_threshold (`float`, *optional*): + Confidence threshold for re-editing already committed tokens. Defaults to the scheduler's value. + eos_early_stop (`bool`, defaults to `True`): + Whether to stop generating further canvases once every sequence has emitted EOS. + eos_token_id (`int`, *optional*): + EOS token ID for early stopping. Falls back to the processor's tokenizer. + generator (`torch.Generator`, *optional*): + RNG for sampling. + output_type (`str`, defaults to `"text"`): + `"text"` decodes sequences into strings (requires a processor); `"seq"` returns token IDs only. + return_dict (`bool`, defaults to `True`): + Whether to return a [`DiffusionGemmaPipelineOutput`] instead of a tuple. + callback_on_step_end (`Callable` or `PipelineCallback`, *optional*): + Callback run after each denoising step with signature `callback_on_step_end(self, step, timestep, + callback_kwargs)`. Allowed tensor keys: `canvas`, `logits`. + callback_on_step_end_tensor_inputs (`List[str]`, *optional*): + Tensor keys to pass to the callback. + + Examples: + + Returns: + [`~pipelines.diffusion_gemma.pipeline_diffusion_gemma.DiffusionGemmaPipelineOutput`] or `tuple`: + The generated token IDs (`sequences`) and, for `output_type="text"`, the decoded `texts`. 
+ """ + if callback_on_step_end is not None and isinstance( + callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks) + ): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + if callback_on_step_end_tensor_inputs is None: + callback_on_step_end_tensor_inputs = ["canvas"] + + self.check_inputs( + prompt=prompt, + messages=messages, + input_ids=input_ids, + gen_length=gen_length, + num_inference_steps=num_inference_steps, + output_type=output_type, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + prompt_ids, prompt_attention_mask = self._prepare_input_ids( + prompt=prompt, + messages=messages, + input_ids=input_ids, + attention_mask=attention_mask, + add_generation_prompt=add_generation_prompt, + ) + + device = self._execution_device + prompt_ids = prompt_ids.to(device=device) + prompt_attention_mask = prompt_attention_mask.to(device=device) + batch_size, prompt_length = prompt_ids.shape + + if eos_token_id is None: + eos_token_id = self.eos_token_id + + canvas_length = self.canvas_length + num_canvases = (gen_length + canvas_length - 1) // canvas_length + self.scheduler.set_timesteps(num_inference_steps, device=device, block_length=canvas_length) + self._num_timesteps = num_inference_steps * num_canvases + + cur_input_ids = prompt_ids + cur_attention_mask = prompt_attention_mask + finished = torch.zeros(batch_size, dtype=torch.bool, device=device) + global_step = 0 + + progress_bar = tqdm(range(num_canvases), **getattr(self, "_progress_bar_config", {})) + for _ in progress_bar: + cur_len = cur_input_ids.shape[1] + decoder_position_ids = torch.arange(cur_len, cur_len + canvas_length, device=device).unsqueeze(0) + decoder_attention_mask = F.pad(cur_attention_mask.bool(), (0, canvas_length), value=True) + + # Start from a fully random canvas and denoise it; the scheduler resets its committed state at step 0. 
+ canvas = torch.randint(0, self.vocab_size, (batch_size, canvas_length), device=device, generator=generator) + self_conditioning_logits = None + + for step_idx in range(num_inference_steps): + outputs = self.model( + input_ids=cur_input_ids, + attention_mask=cur_attention_mask, + decoder_input_ids=canvas, + decoder_position_ids=decoder_position_ids, + decoder_attention_mask=decoder_attention_mask, + self_conditioning_logits=self_conditioning_logits, + ) + logits = outputs.logits + self_conditioning_logits = logits + + scheduler_output = self.scheduler.step( + model_output=logits, + timestep=step_idx, + sample=canvas, + mask_token_id=None, + temperature=temperature, + top_p=top_p, + top_k=top_k, + threshold=threshold, + editing_threshold=editing_threshold, + generator=generator, + return_dict=True, + ) + canvas = scheduler_output.prev_sample + + if callback_on_step_end is not None: + callback_kwargs = {k: locals()[k] for k in callback_on_step_end_tensor_inputs} + callback_outputs = callback_on_step_end(self, global_step, step_idx, callback_kwargs) + canvas = callback_outputs.pop("canvas", canvas) + global_step += 1 + + # Append the denoised canvas and extend the context for the next block. + cur_input_ids = torch.cat([cur_input_ids, canvas], dim=-1) + cur_attention_mask = F.pad(cur_attention_mask, (0, canvas_length), value=1) + + if eos_early_stop and eos_token_id is not None: + finished = finished | (canvas == eos_token_id).any(dim=-1) + if finished.all(): + break + + progress_bar.close() + + sequences = cur_input_ids[:, prompt_length:] + + # Trim each row at its first EOS so post-EOS canvas tokens don't leak into the decoded text. 
+ decode_sequences: list[torch.LongTensor] | torch.LongTensor = sequences + if eos_token_id is not None: + decode_sequences = [ + seq[: int((seq == eos_token_id).nonzero(as_tuple=True)[0][0]) + 1] + if (seq == eos_token_id).any() + else seq + for seq in sequences + ] + + texts = None + if output_type == "text" and self.processor is not None: + texts = self.processor.batch_decode(decode_sequences, skip_special_tokens=True) + + if not return_dict: + return sequences, texts + return DiffusionGemmaPipelineOutput(sequences=sequences, texts=texts) + + +__all__ = ["DiffusionGemmaPipeline", "DiffusionGemmaPipelineOutput"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 0786186dff53..ecc3da616248 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1457,6 +1457,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class DiffusionGemmaPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class DiffusionGemmaPipelineOutput(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class DreamLiteMobilePipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From 
375d63a193b67a19aa5b8c930540fd94245dd328 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 18 Jun 2026 13:15:44 +0200 Subject: [PATCH 03/17] Add DiffusionGemma pipeline tests and docs --- docs/source/en/_toctree.yml | 2 + .../en/api/pipelines/diffusion_gemma.md | 71 ++++++++ tests/pipelines/diffusion_gemma/__init__.py | 0 .../diffusion_gemma/test_diffusion_gemma.py | 171 ++++++++++++++++++ 4 files changed, 244 insertions(+) create mode 100644 docs/source/en/api/pipelines/diffusion_gemma.md create mode 100644 tests/pipelines/diffusion_gemma/__init__.py create mode 100644 tests/pipelines/diffusion_gemma/test_diffusion_gemma.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index b1b7ffebb780..27547c2fc546 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -525,6 +525,8 @@ title: DDPM - local: api/pipelines/deepfloyd_if title: DeepFloyd IF + - local: api/pipelines/diffusion_gemma + title: DiffusionGemma - local: api/pipelines/dit title: DiT - local: api/pipelines/dreamlite diff --git a/docs/source/en/api/pipelines/diffusion_gemma.md b/docs/source/en/api/pipelines/diffusion_gemma.md new file mode 100644 index 000000000000..b7094244bd0e --- /dev/null +++ b/docs/source/en/api/pipelines/diffusion_gemma.md @@ -0,0 +1,71 @@ + + +# DiffusionGemma + +DiffusionGemma is a block-diffusion encoder-decoder language model. A causal encoder reads the clean prompt (and any +previously generated blocks) into a KV cache, and a bidirectional decoder denoises a fixed-size "canvas" of +`canvas_length` tokens by cross-attending to that cache. Generation alternates an outer autoregressive loop over +canvases with an inner denoising loop, where each step samples candidate tokens, commits the most confident ones via +[`BlockRefinementScheduler`] in uniform corruption mode, and renoises the rest. The model itself lives in +`transformers` as `DiffusionGemmaForBlockDiffusion`. 
+ +## Usage + +```py +import torch +from transformers import AutoProcessor, DiffusionGemmaForBlockDiffusion + +from diffusers import BlockRefinementScheduler, DiffusionGemmaPipeline + +model_id = "google/diffusiongemma-26B-A4B-it" +model = DiffusionGemmaForBlockDiffusion.from_pretrained(model_id, dtype=torch.bfloat16, device_map="auto") +processor = AutoProcessor.from_pretrained(model_id) +scheduler = BlockRefinementScheduler() + +pipe = DiffusionGemmaPipeline(model=model, scheduler=scheduler, processor=processor) +output = pipe( + prompt="Why is the sky blue?", + gen_length=256, + num_inference_steps=32, + temperature=0.0, +) +print(output.texts[0]) +``` + +## Callbacks + +Callbacks run after each denoising step. Pass `callback_on_step_end_tensor_inputs` to select which tensors are +included in `callback_kwargs`; `canvas` (the current block tokens) and `logits` are available. Return `{"canvas": ...}` +from the callback to replace the canvas. + +```py +def on_step_end(pipe, step, timestep, callback_kwargs): + canvas = callback_kwargs["canvas"] + # Inspect or modify `canvas` here. 
+ return {"canvas": canvas} + + +out = pipe( + prompt="Why is the sky blue?", + callback_on_step_end=on_step_end, + callback_on_step_end_tensor_inputs=["canvas"], +) +``` + +## DiffusionGemmaPipeline +[[autodoc]] DiffusionGemmaPipeline + - all + - __call__ + +## DiffusionGemmaPipelineOutput +[[autodoc]] pipelines.DiffusionGemmaPipelineOutput diff --git a/tests/pipelines/diffusion_gemma/__init__.py b/tests/pipelines/diffusion_gemma/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/diffusion_gemma/test_diffusion_gemma.py b/tests/pipelines/diffusion_gemma/test_diffusion_gemma.py new file mode 100644 index 000000000000..0d4b399d6377 --- /dev/null +++ b/tests/pipelines/diffusion_gemma/test_diffusion_gemma.py @@ -0,0 +1,171 @@ +import unittest + +import torch + +from diffusers import BlockRefinementScheduler, DiffusionGemmaPipeline + + +class _DummyModelOutput: + def __init__(self, logits): + self.logits = logits + + +class _DummyTextConfig: + def __init__(self, vocab_size: int): + self.vocab_size = int(vocab_size) + self.eos_token_id = None + + +class _DummyConfig: + def __init__(self, canvas_length: int, vocab_size: int): + self.canvas_length = int(canvas_length) + self._text_config = _DummyTextConfig(vocab_size) + + def get_text_config(self): + return self._text_config + + +class _DummyBlockDiffusionModel(torch.nn.Module): + """Stand-in for `DiffusionGemmaForBlockDiffusion`: returns logits over the decoder canvas.""" + + def __init__(self, vocab_size: int = 32, canvas_length: int = 8): + super().__init__() + self.vocab_size = int(vocab_size) + self.config = _DummyConfig(canvas_length, vocab_size) + self.register_buffer("_device_anchor", torch.empty(0)) + + @property + def dtype(self): + return torch.float32 + + @property + def device(self): + return self._device_anchor.device + + def forward(self, input_ids=None, decoder_input_ids=None, **kwargs): + batch_size, canvas_len = decoder_input_ids.shape + device = 
decoder_input_ids.device + logits = torch.zeros((batch_size, canvas_len, self.vocab_size), device=device, dtype=torch.float32) + # Make confidence vary with canvas position so the commit quota is deterministic. + positions = torch.arange(canvas_len, device=device, dtype=torch.float32).view(1, canvas_len, 1) + token_ids = (torch.arange(canvas_len, device=device) % (self.vocab_size - 2)).view(1, canvas_len, 1) + logits.scatter_(2, token_ids.expand(batch_size, -1, -1), 1.0 + positions.expand(batch_size, -1, -1) * 0.1) + return _DummyModelOutput(logits=logits) + + +def _make_pipeline(processor=None, canvas_length: int = 8): + model = _DummyBlockDiffusionModel(vocab_size=32, canvas_length=canvas_length) + scheduler = BlockRefinementScheduler() + return DiffusionGemmaPipeline(model=model, scheduler=scheduler, processor=processor) + + +class DiffusionGemmaPipelineTest(unittest.TestCase): + def test_pipeline_runs(self): + pipe = _make_pipeline().to("cpu") + input_ids = torch.tensor([[5, 6, 7, 8], [1, 2, 3, 4]], dtype=torch.long) + out = pipe( + input_ids=input_ids, + gen_length=24, # 3 canvases of length 8 + num_inference_steps=8, + temperature=0.0, + eos_early_stop=False, + output_type="seq", + ) + self.assertEqual(out.sequences.shape, (2, 24)) + self.assertIsNone(out.texts) + + def test_pipeline_return_tuple(self): + pipe = _make_pipeline().to("cpu") + sequences, texts = pipe( + input_ids=torch.tensor([[5, 6, 7, 8]], dtype=torch.long), + gen_length=16, + num_inference_steps=4, + eos_early_stop=False, + output_type="seq", + return_dict=False, + ) + self.assertEqual(sequences.shape, (1, 16)) + self.assertIsNone(texts) + + def test_output_type_text_with_processor(self): + processor = type( + "Proc", + (), + { + "tokenizer": type("Tok", (), {"eos_token_id": None})(), + "batch_decode": lambda self, seqs, **kw: [f"decoded_{len(s)}" for s in seqs], + }, + )() + pipe = _make_pipeline(processor=processor).to("cpu") + out = pipe( + input_ids=torch.tensor([[5, 6, 7, 8]], 
dtype=torch.long), + gen_length=16, + num_inference_steps=4, + eos_early_stop=False, + output_type="text", + ) + self.assertIsNotNone(out.texts) + self.assertEqual(len(out.texts), 1) + self.assertTrue(out.texts[0].startswith("decoded_")) + + def test_output_type_invalid_raises(self): + pipe = _make_pipeline().to("cpu") + with self.assertRaises(ValueError): + pipe( + input_ids=torch.tensor([[5, 6, 7, 8]], dtype=torch.long), + gen_length=8, + num_inference_steps=2, + output_type="invalid", + ) + + def test_no_inputs_raises(self): + pipe = _make_pipeline().to("cpu") + with self.assertRaises(ValueError): + pipe(gen_length=8, num_inference_steps=2, output_type="seq") + + def test_prepare_input_ids_from_1d_tensor(self): + pipe = _make_pipeline() + ids = torch.tensor([1, 2, 3], dtype=torch.long) + result_ids, result_mask = pipe._prepare_input_ids( + prompt=None, messages=None, input_ids=ids, attention_mask=None, add_generation_prompt=False + ) + self.assertEqual(result_ids.shape, (1, 3)) + self.assertEqual(result_mask.shape, (1, 3)) + self.assertTrue((result_mask == 1).all().item()) + + def test_callback_receives_advertised_keys(self): + observed: list[str] = [] + + def cb(pipe, step, timestep, kwargs): + observed.extend(sorted(kwargs.keys())) + return {} + + pipe = _make_pipeline().to("cpu") + keys = list(pipe._callback_tensor_inputs) + pipe( + input_ids=torch.tensor([[5, 6, 7, 8]], dtype=torch.long), + gen_length=8, + num_inference_steps=4, + eos_early_stop=False, + output_type="seq", + callback_on_step_end=cb, + callback_on_step_end_tensor_inputs=keys, + ) + self.assertEqual(set(observed), set(keys)) + + def test_progress_bar_disable_is_preserved_after_call(self): + pipe = _make_pipeline().to("cpu") + pipe.set_progress_bar_config(disable=True) + before = dict(pipe._progress_bar_config) + pipe( + input_ids=torch.tensor([[5, 6, 7, 8]], dtype=torch.long), + gen_length=8, + num_inference_steps=2, + eos_early_stop=False, + output_type="seq", + ) + 
self.assertEqual(pipe._progress_bar_config, before) + + +if __name__ == "__main__": + unittest.main() From 245a6efac971622eccb959ecc216e1d23162f6dd Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 18 Jun 2026 15:20:21 +0200 Subject: [PATCH 04/17] Put DiffusionGemma docs under the Text pipelines section --- docs/source/en/_toctree.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index b85b09dc835b..9347e878c3ba 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -525,8 +525,6 @@ title: DDPM - local: api/pipelines/deepfloyd_if title: DeepFloyd IF - - local: api/pipelines/diffusion_gemma - title: DiffusionGemma - local: api/pipelines/dit title: DiT - local: api/pipelines/dreamlite @@ -645,6 +643,8 @@ title: Z-Image title: Image - sections: + - local: api/pipelines/diffusion_gemma + title: DiffusionGemma - local: api/pipelines/llada2 title: LLaDA2 title: Text From aae6e0298260b9eb23c2dfdbadd6e7cfbddd6e42 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 18 Jun 2026 15:50:59 +0200 Subject: [PATCH 05/17] Add static cache and fullgraph-compiled decoder path to DiffusionGemma pipeline --- .../pipeline_diffusion_gemma.py | 81 ++++++++++++++++--- .../diffusion_gemma/test_diffusion_gemma.py | 35 ++++++++ 2 files changed, 106 insertions(+), 10 deletions(-) diff --git a/src/diffusers/pipelines/diffusion_gemma/pipeline_diffusion_gemma.py b/src/diffusers/pipelines/diffusion_gemma/pipeline_diffusion_gemma.py index e4f97f03c6ff..6a1240aad9cc 100644 --- a/src/diffusers/pipelines/diffusion_gemma/pipeline_diffusion_gemma.py +++ b/src/diffusers/pipelines/diffusion_gemma/pipeline_diffusion_gemma.py @@ -20,6 +20,7 @@ import torch import torch.nn.functional as F from tqdm.auto import tqdm +from transformers import StaticCache from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...schedulers import BlockRefinementScheduler @@ -183,6 +184,8 @@ def 
__call__( top_k: int | None = None, threshold: float | None = None, editing_threshold: float | None = None, + cache_implementation: str | None = None, + compile_decoder: bool = False, eos_early_stop: bool = True, eos_token_id: int | None = None, generator: torch.Generator | None = None, @@ -223,6 +226,13 @@ def __call__( Confidence threshold for committing tokens. Defaults to the scheduler's configured value. editing_threshold (`float`, *optional*): Confidence threshold for re-editing already committed tokens. Defaults to the scheduler's value. + cache_implementation (`str`, *optional*): + Set to `"static"` to prefill the encoder once per block into a persistent `StaticCache` and run the + decoder against it with fixed shapes, instead of re-encoding the full sequence on every step. Required + for `compile_decoder`. + compile_decoder (`bool`, defaults to `False`): + Whether to `torch.compile(fullgraph=True)` the decoder forward. Only takes effect with + `cache_implementation="static"`, whose fixed shapes make the decoder graph-break free. eos_early_stop (`bool`, defaults to `True`): Whether to stop generating further canvases once every sequence has emitted EOS. eos_token_id (`int`, *optional*): @@ -288,26 +298,77 @@ def __call__( finished = torch.zeros(batch_size, dtype=torch.bool, device=device) global_step = 0 + # With `cache_implementation="static"` the encoder prefills a persistent `StaticCache` once per block and the + # decoder runs with fixed shapes against it (instead of re-encoding the full sequence every step), which also + # lets the decoder forward be `torch.compile(fullgraph=True)`-d when `compile_decoder=True`. 
+ use_static_cache = cache_implementation == "static" + past_key_values = None + decoder_logits = None + if use_static_cache: + text_config = self.model.config.get_text_config(decoder=True) + max_cache_len = prompt_length + num_canvases * canvas_length + past_key_values = StaticCache(config=text_config, max_cache_len=max_cache_len) + + def decoder_logits(canvas, self_cond, mask_mapping, dec_pos): + return self.model( + decoder_input_ids=canvas, + past_key_values=past_key_values, + self_conditioning_logits=self_cond, + decoder_attention_mask=mask_mapping, + decoder_position_ids=dec_pos, + ).logits + + if compile_decoder: + decoder_logits = torch.compile(decoder_logits, fullgraph=True) + progress_bar = tqdm(range(num_canvases), **getattr(self, "_progress_bar_config", {})) for _ in progress_bar: cur_len = cur_input_ids.shape[1] decoder_position_ids = torch.arange(cur_len, cur_len + canvas_length, device=device).unsqueeze(0) - decoder_attention_mask = F.pad(cur_attention_mask.bool(), (0, canvas_length), value=True) + + mask_mapping = None + if use_static_cache: + # Encode the tokens not yet in the cache (the whole prompt on the first block, the last committed + # canvas afterwards), then build the fixed-size 4D decoder mask once for this block (outside any + # compiled region, so the compiled decoder never constructs masks). 
+ cached_len = past_key_values.get_seq_length() + new_tokens = cur_input_ids[:, cached_len:] + self.model.model.encoder( + input_ids=new_tokens, + attention_mask=cur_attention_mask, + past_key_values=past_key_values, + position_ids=torch.arange(cached_len, cur_len, device=device).unsqueeze(0), + ) + decoder_attention_mask = torch.zeros( + (batch_size, max_cache_len + canvas_length), dtype=torch.bool, device=device + ) + decoder_attention_mask[:, :cur_len] = cur_attention_mask.bool() + decoder_attention_mask[:, -canvas_length:] = True + mask_mapping = self.model.model.decoder.create_diffusion_decoder_attention_mask( + config=text_config, + inputs_embeds=torch.empty((batch_size, canvas_length, 0), device=device), + past_key_values=past_key_values, + decoder_attention_mask=decoder_attention_mask, + ) + else: + decoder_attention_mask = F.pad(cur_attention_mask.bool(), (0, canvas_length), value=True) # Start from a fully random canvas and denoise it; the scheduler resets its committed state at step 0. 
canvas = torch.randint(0, self.vocab_size, (batch_size, canvas_length), device=device, generator=generator) self_conditioning_logits = None for step_idx in range(num_inference_steps): - outputs = self.model( - input_ids=cur_input_ids, - attention_mask=cur_attention_mask, - decoder_input_ids=canvas, - decoder_position_ids=decoder_position_ids, - decoder_attention_mask=decoder_attention_mask, - self_conditioning_logits=self_conditioning_logits, - ) - logits = outputs.logits + if use_static_cache: + logits = decoder_logits(canvas, self_conditioning_logits, mask_mapping, decoder_position_ids) + else: + logits = self.model( + input_ids=cur_input_ids, + attention_mask=cur_attention_mask, + decoder_input_ids=canvas, + decoder_position_ids=decoder_position_ids, + decoder_attention_mask=decoder_attention_mask, + self_conditioning_logits=self_conditioning_logits, + ).logits self_conditioning_logits = logits scheduler_output = self.scheduler.step( diff --git a/tests/pipelines/diffusion_gemma/test_diffusion_gemma.py b/tests/pipelines/diffusion_gemma/test_diffusion_gemma.py index 0d4b399d6377..f2a13bf07772 100644 --- a/tests/pipelines/diffusion_gemma/test_diffusion_gemma.py +++ b/tests/pipelines/diffusion_gemma/test_diffusion_gemma.py @@ -167,5 +167,40 @@ def test_progress_bar_disable_is_preserved_after_call(self): self.assertEqual(pipe._progress_bar_config, before) +class DiffusionGemmaStaticCacheTest(unittest.TestCase): + """The static-cache path uses the real model internals (encoder prefill + `StaticCache`), so it needs the tiny + checkpoint rather than a stand-in. Skips when the model can't be fetched (e.g. 
offline CI).""" + + def _load_pipeline(self): + try: + from transformers import AutoProcessor, DiffusionGemmaForBlockDiffusion + except ImportError as e: + self.skipTest(f"transformers without DiffusionGemma: {e}") + model_id = "trl-internal-testing/tiny-DiffusionGemmaForBlockDiffusion" + try: + model = DiffusionGemmaForBlockDiffusion.from_pretrained(model_id, dtype=torch.float32).eval() + processor = AutoProcessor.from_pretrained(model_id) + except Exception as e: # noqa: BLE001 - offline / hub errors should skip, not fail + self.skipTest(f"tiny DiffusionGemma checkpoint unavailable: {e}") + pipe = DiffusionGemmaPipeline(model=model, scheduler=BlockRefinementScheduler(), processor=processor) + pipe.set_progress_bar_config(disable=True) + return pipe, model.config.canvas_length + + def test_static_cache_matches_dynamic(self): + pipe, canvas_length = self._load_pipeline() + kwargs = dict( + messages=[{"role": "user", "content": "Name a color."}], + gen_length=canvas_length * 2, # two canvases -> exercises the cache extension between blocks + num_inference_steps=4, + temperature=0.0, + eos_early_stop=False, + output_type="seq", + ) + dynamic = pipe(generator=torch.Generator().manual_seed(0), **kwargs).sequences + static = pipe(generator=torch.Generator().manual_seed(0), cache_implementation="static", **kwargs).sequences + self.assertEqual(dynamic.shape, (1, canvas_length * 2)) + self.assertTrue(torch.equal(dynamic, static)) + + if __name__ == "__main__": unittest.main() From d6018817c382a4a51e34d0f3dc09303618c23680 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 18 Jun 2026 16:09:00 +0200 Subject: [PATCH 06/17] Compile decoder externally for the static cache path instead of a pipeline flag --- .../en/api/pipelines/diffusion_gemma.md | 11 ++++++ .../pipeline_diffusion_gemma.py | 36 ++++++++----------- .../diffusion_gemma/test_diffusion_gemma.py | 16 ++++----- 3 files changed, 33 insertions(+), 30 deletions(-) diff --git 
a/docs/source/en/api/pipelines/diffusion_gemma.md b/docs/source/en/api/pipelines/diffusion_gemma.md index b7094244bd0e..6061844b25b4 100644 --- a/docs/source/en/api/pipelines/diffusion_gemma.md +++ b/docs/source/en/api/pipelines/diffusion_gemma.md @@ -42,6 +42,17 @@ output = pipe( print(output.texts[0]) ``` +## Static cache and compilation + +By default the pipeline re-encodes the prompt on every denoising step. Pass `cache_implementation="static"` to instead +prefill the encoder once per block into a persistent `StaticCache` and run the decoder against it with fixed shapes. +The fixed shapes let you `torch.compile` the decoder for a further speedup: + +```py +pipe.model.model.decoder = torch.compile(pipe.model.model.decoder, fullgraph=True) +output = pipe(prompt="Why is the sky blue?", gen_length=256, cache_implementation="static") +``` + ## Callbacks Callbacks run after each denoising step. Pass `callback_on_step_end_tensor_inputs` to select which tensors are diff --git a/src/diffusers/pipelines/diffusion_gemma/pipeline_diffusion_gemma.py b/src/diffusers/pipelines/diffusion_gemma/pipeline_diffusion_gemma.py index 6a1240aad9cc..39ddb620c253 100644 --- a/src/diffusers/pipelines/diffusion_gemma/pipeline_diffusion_gemma.py +++ b/src/diffusers/pipelines/diffusion_gemma/pipeline_diffusion_gemma.py @@ -185,7 +185,6 @@ def __call__( threshold: float | None = None, editing_threshold: float | None = None, cache_implementation: str | None = None, - compile_decoder: bool = False, eos_early_stop: bool = True, eos_token_id: int | None = None, generator: torch.Generator | None = None, @@ -228,11 +227,9 @@ def __call__( Confidence threshold for re-editing already committed tokens. Defaults to the scheduler's value. cache_implementation (`str`, *optional*): Set to `"static"` to prefill the encoder once per block into a persistent `StaticCache` and run the - decoder against it with fixed shapes, instead of re-encoding the full sequence on every step. 
Required - for `compile_decoder`. - compile_decoder (`bool`, defaults to `False`): - Whether to `torch.compile(fullgraph=True)` the decoder forward. Only takes effect with - `cache_implementation="static"`, whose fixed shapes make the decoder graph-break free. + decoder against it with fixed shapes, instead of re-encoding the full sequence on every step. The + fixed shapes also let you compile the decoder, e.g. + `pipe.model.model.decoder = torch.compile(pipe.model.model.decoder, fullgraph=True)`. eos_early_stop (`bool`, defaults to `True`): Whether to stop generating further canvases once every sequence has emitted EOS. eos_token_id (`int`, *optional*): @@ -299,28 +296,17 @@ def __call__( global_step = 0 # With `cache_implementation="static"` the encoder prefills a persistent `StaticCache` once per block and the - # decoder runs with fixed shapes against it (instead of re-encoding the full sequence every step), which also - # lets the decoder forward be `torch.compile(fullgraph=True)`-d when `compile_decoder=True`. + # decoder runs with fixed shapes against it (instead of re-encoding the full sequence every step). The fixed + # shapes also let a user `torch.compile` the decoder module fullgraph. 
use_static_cache = cache_implementation == "static" past_key_values = None - decoder_logits = None + text_config = None + max_cache_len = None if use_static_cache: text_config = self.model.config.get_text_config(decoder=True) max_cache_len = prompt_length + num_canvases * canvas_length past_key_values = StaticCache(config=text_config, max_cache_len=max_cache_len) - def decoder_logits(canvas, self_cond, mask_mapping, dec_pos): - return self.model( - decoder_input_ids=canvas, - past_key_values=past_key_values, - self_conditioning_logits=self_cond, - decoder_attention_mask=mask_mapping, - decoder_position_ids=dec_pos, - ).logits - - if compile_decoder: - decoder_logits = torch.compile(decoder_logits, fullgraph=True) - progress_bar = tqdm(range(num_canvases), **getattr(self, "_progress_bar_config", {})) for _ in progress_bar: cur_len = cur_input_ids.shape[1] @@ -359,7 +345,13 @@ def decoder_logits(canvas, self_cond, mask_mapping, dec_pos): for step_idx in range(num_inference_steps): if use_static_cache: - logits = decoder_logits(canvas, self_conditioning_logits, mask_mapping, decoder_position_ids) + logits = self.model( + decoder_input_ids=canvas, + past_key_values=past_key_values, + self_conditioning_logits=self_conditioning_logits, + decoder_attention_mask=mask_mapping, + decoder_position_ids=decoder_position_ids, + ).logits else: logits = self.model( input_ids=cur_input_ids, diff --git a/tests/pipelines/diffusion_gemma/test_diffusion_gemma.py b/tests/pipelines/diffusion_gemma/test_diffusion_gemma.py index f2a13bf07772..e73595982757 100644 --- a/tests/pipelines/diffusion_gemma/test_diffusion_gemma.py +++ b/tests/pipelines/diffusion_gemma/test_diffusion_gemma.py @@ -188,14 +188,14 @@ def _load_pipeline(self): def test_static_cache_matches_dynamic(self): pipe, canvas_length = self._load_pipeline() - kwargs = dict( - messages=[{"role": "user", "content": "Name a color."}], - gen_length=canvas_length * 2, # two canvases -> exercises the cache extension between blocks - 
num_inference_steps=4, - temperature=0.0, - eos_early_stop=False, - output_type="seq", - ) + kwargs = { + "messages": [{"role": "user", "content": "Name a color."}], + "gen_length": canvas_length * 2, # two canvases -> exercises the cache extension between blocks + "num_inference_steps": 4, + "temperature": 0.0, + "eos_early_stop": False, + "output_type": "seq", + } dynamic = pipe(generator=torch.Generator().manual_seed(0), **kwargs).sequences static = pipe(generator=torch.Generator().manual_seed(0), cache_implementation="static", **kwargs).sequences self.assertEqual(dynamic.shape, (1, canvas_length * 2)) From 93517817bd6e79afa77fb558d2d6c1acc62b51be Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 18 Jun 2026 16:45:16 +0200 Subject: [PATCH 07/17] Prefill the encoder once into a reusable cache and sync default denoising steps --- .../pipeline_diffusion_gemma.py | 95 ++++---- .../diffusion_gemma/test_diffusion_gemma.py | 219 +++++++----------- 2 files changed, 124 insertions(+), 190 deletions(-) diff --git a/src/diffusers/pipelines/diffusion_gemma/pipeline_diffusion_gemma.py b/src/diffusers/pipelines/diffusion_gemma/pipeline_diffusion_gemma.py index 39ddb620c253..ee9793d5b1c0 100644 --- a/src/diffusers/pipelines/diffusion_gemma/pipeline_diffusion_gemma.py +++ b/src/diffusers/pipelines/diffusion_gemma/pipeline_diffusion_gemma.py @@ -20,7 +20,7 @@ import torch import torch.nn.functional as F from tqdm.auto import tqdm -from transformers import StaticCache +from transformers import DynamicCache, StaticCache from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...schedulers import BlockRefinementScheduler @@ -178,7 +178,7 @@ def __call__( attention_mask: torch.LongTensor | None = None, add_generation_prompt: bool = True, gen_length: int = 256, - num_inference_steps: int = 32, + num_inference_steps: int = 48, temperature: float = 0.0, top_p: float | None = None, top_k: int | None = None, @@ -213,7 +213,7 @@ def __call__( Whether to add the 
generation prompt when applying the chat template. gen_length (`int`, defaults to `256`): Number of tokens to generate, rounded up to a multiple of the model's `canvas_length`. - num_inference_steps (`int`, defaults to `32`): + num_inference_steps (`int`, defaults to `48`): Number of denoising steps per canvas. temperature (`float`, defaults to `0.0`): Sampling temperature. `0.0` is greedy. @@ -295,72 +295,59 @@ def __call__( finished = torch.zeros(batch_size, dtype=torch.bool, device=device) global_step = 0 - # With `cache_implementation="static"` the encoder prefills a persistent `StaticCache` once per block and the - # decoder runs with fixed shapes against it (instead of re-encoding the full sequence every step). The fixed - # shapes also let a user `torch.compile` the decoder module fullgraph. + # Encode each block of context once into a reusable KV cache and run the decoder against it, rather than + # re-encoding the whole sequence on every denoising step. The default `DynamicCache` grows with the context; + # `cache_implementation="static"` uses a fixed-shape `StaticCache` so the decoder can be `torch.compile`-d. 
use_static_cache = cache_implementation == "static" - past_key_values = None - text_config = None - max_cache_len = None + text_config = self.model.config.get_text_config(decoder=True) + max_cache_len = prompt_length + num_canvases * canvas_length if use_static_cache: - text_config = self.model.config.get_text_config(decoder=True) - max_cache_len = prompt_length + num_canvases * canvas_length past_key_values = StaticCache(config=text_config, max_cache_len=max_cache_len) + else: + past_key_values = DynamicCache(config=text_config) progress_bar = tqdm(range(num_canvases), **getattr(self, "_progress_bar_config", {})) for _ in progress_bar: cur_len = cur_input_ids.shape[1] decoder_position_ids = torch.arange(cur_len, cur_len + canvas_length, device=device).unsqueeze(0) - mask_mapping = None - if use_static_cache: - # Encode the tokens not yet in the cache (the whole prompt on the first block, the last committed - # canvas afterwards), then build the fixed-size 4D decoder mask once for this block (outside any - # compiled region, so the compiled decoder never constructs masks). 
- cached_len = past_key_values.get_seq_length() - new_tokens = cur_input_ids[:, cached_len:] - self.model.model.encoder( - input_ids=new_tokens, - attention_mask=cur_attention_mask, - past_key_values=past_key_values, - position_ids=torch.arange(cached_len, cur_len, device=device).unsqueeze(0), - ) - decoder_attention_mask = torch.zeros( - (batch_size, max_cache_len + canvas_length), dtype=torch.bool, device=device - ) - decoder_attention_mask[:, :cur_len] = cur_attention_mask.bool() - decoder_attention_mask[:, -canvas_length:] = True - mask_mapping = self.model.model.decoder.create_diffusion_decoder_attention_mask( - config=text_config, - inputs_embeds=torch.empty((batch_size, canvas_length, 0), device=device), - past_key_values=past_key_values, - decoder_attention_mask=decoder_attention_mask, - ) - else: - decoder_attention_mask = F.pad(cur_attention_mask.bool(), (0, canvas_length), value=True) + # Encode the tokens not yet in the cache (the whole prompt on the first block, the last committed canvas + # afterwards), so the decoder reuses the encoder KV cache instead of re-encoding the full sequence. + cached_len = past_key_values.get_seq_length() + self.model.model.encoder( + input_ids=cur_input_ids[:, cached_len:], + attention_mask=cur_attention_mask, + past_key_values=past_key_values, + position_ids=torch.arange(cached_len, cur_len, device=device).unsqueeze(0), + ) + + # Build the 4D decoder mask once per block (outside any compiled region). A static cache spans its full + # buffer; a dynamic cache spans only the populated length. 
+ cache_buffer_len = max_cache_len if use_static_cache else cur_len + decoder_attention_mask = torch.zeros( + (batch_size, cache_buffer_len + canvas_length), dtype=torch.bool, device=device + ) + decoder_attention_mask[:, :cur_len] = cur_attention_mask.bool() + decoder_attention_mask[:, -canvas_length:] = True + mask_mapping = self.model.model.decoder.create_diffusion_decoder_attention_mask( + config=text_config, + inputs_embeds=torch.empty((batch_size, canvas_length, 0), device=device), + past_key_values=past_key_values, + decoder_attention_mask=decoder_attention_mask, + ) # Start from a fully random canvas and denoise it; the scheduler resets its committed state at step 0. canvas = torch.randint(0, self.vocab_size, (batch_size, canvas_length), device=device, generator=generator) self_conditioning_logits = None for step_idx in range(num_inference_steps): - if use_static_cache: - logits = self.model( - decoder_input_ids=canvas, - past_key_values=past_key_values, - self_conditioning_logits=self_conditioning_logits, - decoder_attention_mask=mask_mapping, - decoder_position_ids=decoder_position_ids, - ).logits - else: - logits = self.model( - input_ids=cur_input_ids, - attention_mask=cur_attention_mask, - decoder_input_ids=canvas, - decoder_position_ids=decoder_position_ids, - decoder_attention_mask=decoder_attention_mask, - self_conditioning_logits=self_conditioning_logits, - ).logits + logits = self.model( + decoder_input_ids=canvas, + past_key_values=past_key_values, + self_conditioning_logits=self_conditioning_logits, + decoder_attention_mask=mask_mapping, + decoder_position_ids=decoder_position_ids, + ).logits self_conditioning_logits = logits scheduler_output = self.scheduler.step( diff --git a/tests/pipelines/diffusion_gemma/test_diffusion_gemma.py b/tests/pipelines/diffusion_gemma/test_diffusion_gemma.py index e73595982757..7b564eb0db37 100644 --- a/tests/pipelines/diffusion_gemma/test_diffusion_gemma.py +++ 
b/tests/pipelines/diffusion_gemma/test_diffusion_gemma.py @@ -5,9 +5,7 @@ from diffusers import BlockRefinementScheduler, DiffusionGemmaPipeline -class _DummyModelOutput: - def __init__(self, logits): - self.logits = logits +# --- Lightweight stand-in for input-validation tests that never reach the model --- class _DummyTextConfig: @@ -21,110 +19,41 @@ def __init__(self, canvas_length: int, vocab_size: int): self.canvas_length = int(canvas_length) self._text_config = _DummyTextConfig(vocab_size) - def get_text_config(self): + def get_text_config(self, decoder: bool = False): return self._text_config -class _DummyBlockDiffusionModel(torch.nn.Module): - """Stand-in for `DiffusionGemmaForBlockDiffusion`: returns logits over the decoder canvas.""" - +class _DummyModel(torch.nn.Module): def __init__(self, vocab_size: int = 32, canvas_length: int = 8): super().__init__() - self.vocab_size = int(vocab_size) self.config = _DummyConfig(canvas_length, vocab_size) - self.register_buffer("_device_anchor", torch.empty(0)) - - @property - def dtype(self): - return torch.float32 - - @property - def device(self): - return self._device_anchor.device - def forward(self, input_ids=None, decoder_input_ids=None, **kwargs): - batch_size, canvas_len = decoder_input_ids.shape - device = decoder_input_ids.device - logits = torch.zeros((batch_size, canvas_len, self.vocab_size), device=device, dtype=torch.float32) - # Make confidence vary with canvas position so the commit quota is deterministic. 
- positions = torch.arange(canvas_len, device=device, dtype=torch.float32).view(1, canvas_len, 1) - token_ids = (torch.arange(canvas_len, device=device) % (self.vocab_size - 2)).view(1, canvas_len, 1) - logits.scatter_(2, token_ids.expand(batch_size, -1, -1), 1.0 + positions.expand(batch_size, -1, -1) * 0.1) - return _DummyModelOutput(logits=logits) +def _make_dummy_pipeline(processor=None, canvas_length: int = 8): + model = _DummyModel(vocab_size=32, canvas_length=canvas_length) + return DiffusionGemmaPipeline(model=model, scheduler=BlockRefinementScheduler(), processor=processor) -def _make_pipeline(processor=None, canvas_length: int = 8): - model = _DummyBlockDiffusionModel(vocab_size=32, canvas_length=canvas_length) - scheduler = BlockRefinementScheduler() - return DiffusionGemmaPipeline(model=model, scheduler=scheduler, processor=processor) +class DiffusionGemmaPipelineInputTest(unittest.TestCase): + """Input validation and prompt encoding, which short-circuit before the model is called.""" -class DiffusionGemmaPipelineTest(unittest.TestCase): - def test_pipeline_runs(self): - pipe = _make_pipeline().to("cpu") - input_ids = torch.tensor([[5, 6, 7, 8], [1, 2, 3, 4]], dtype=torch.long) - out = pipe( - input_ids=input_ids, - gen_length=24, # 3 canvases of length 8 - num_inference_steps=8, - temperature=0.0, - eos_early_stop=False, - output_type="seq", - ) - self.assertEqual(out.sequences.shape, (2, 24)) - self.assertIsNone(out.texts) - - def test_pipeline_return_tuple(self): - pipe = _make_pipeline().to("cpu") - sequences, texts = pipe( - input_ids=torch.tensor([[5, 6, 7, 8]], dtype=torch.long), - gen_length=16, - num_inference_steps=4, - eos_early_stop=False, - output_type="seq", - return_dict=False, - ) - self.assertEqual(sequences.shape, (1, 16)) - self.assertIsNone(texts) - - def test_output_type_text_with_processor(self): - processor = type( - "Proc", - (), - { - "tokenizer": type("Tok", (), {"eos_token_id": None})(), - "batch_decode": lambda self, seqs, 
**kw: [f"decoded_{len(s)}" for s in seqs], - }, - )() - pipe = _make_pipeline(processor=processor).to("cpu") - out = pipe( - input_ids=torch.tensor([[5, 6, 7, 8]], dtype=torch.long), - gen_length=16, - num_inference_steps=4, - eos_early_stop=False, - output_type="text", - ) - self.assertIsNotNone(out.texts) - self.assertEqual(len(out.texts), 1) - self.assertTrue(out.texts[0].startswith("decoded_")) + def test_no_inputs_raises(self): + pipe = _make_dummy_pipeline() + with self.assertRaises(ValueError): + pipe(gen_length=8, num_inference_steps=2, output_type="seq") def test_output_type_invalid_raises(self): - pipe = _make_pipeline().to("cpu") + pipe = _make_dummy_pipeline() with self.assertRaises(ValueError): - pipe( - input_ids=torch.tensor([[5, 6, 7, 8]], dtype=torch.long), - gen_length=8, - num_inference_steps=2, - output_type="invalid", - ) + pipe(input_ids=torch.tensor([[1, 2, 3]], dtype=torch.long), gen_length=8, output_type="invalid") - def test_no_inputs_raises(self): - pipe = _make_pipeline().to("cpu") + def test_prompt_and_messages_together_raises(self): + pipe = _make_dummy_pipeline() with self.assertRaises(ValueError): - pipe(gen_length=8, num_inference_steps=2, output_type="seq") + pipe(prompt="hi", messages=[{"role": "user", "content": "hi"}], gen_length=8, output_type="seq") def test_prepare_input_ids_from_1d_tensor(self): - pipe = _make_pipeline() + pipe = _make_dummy_pipeline() ids = torch.tensor([1, 2, 3], dtype=torch.long) result_ids, result_mask = pipe._prepare_input_ids( prompt=None, messages=None, input_ids=ids, attention_mask=None, add_generation_prompt=False @@ -133,72 +62,90 @@ def test_prepare_input_ids_from_1d_tensor(self): self.assertEqual(result_mask.shape, (1, 3)) self.assertTrue((result_mask == 1).all().item()) - def test_callback_receives_advertised_keys(self): - observed: list[str] = [] - def cb(pipe, step, timestep, kwargs): - observed.extend(sorted(kwargs.keys())) - return {} +# --- End-to-end generation: the prefill-once path 
drives the real encoder/decoder, so it needs the tiny model --- + +_MODEL_ID = "trl-internal-testing/tiny-DiffusionGemmaForBlockDiffusion" + - pipe = _make_pipeline().to("cpu") - keys = list(pipe._callback_tensor_inputs) - pipe( - input_ids=torch.tensor([[5, 6, 7, 8]], dtype=torch.long), - gen_length=8, +def _load_pipeline(test): + try: + from transformers import AutoProcessor, DiffusionGemmaForBlockDiffusion + except ImportError as e: + test.skipTest(f"transformers without DiffusionGemma: {e}") + try: + model = DiffusionGemmaForBlockDiffusion.from_pretrained(_MODEL_ID, dtype=torch.float32).eval() + processor = AutoProcessor.from_pretrained(_MODEL_ID) + except Exception as e: # noqa: BLE001 - offline / hub errors should skip, not fail + test.skipTest(f"tiny DiffusionGemma checkpoint unavailable: {e}") + pipe = DiffusionGemmaPipeline(model=model, scheduler=BlockRefinementScheduler(), processor=processor) + pipe.set_progress_bar_config(disable=True) + return pipe, model.config.canvas_length + + +class DiffusionGemmaPipelineTest(unittest.TestCase): + def setUp(self): + self.pipe, self.canvas_length = _load_pipeline(self) + self.messages = [{"role": "user", "content": "Name a color."}] + + def test_generate_seq_shape(self): + out = self.pipe( + messages=self.messages, + gen_length=self.canvas_length * 2, num_inference_steps=4, + temperature=0.0, eos_early_stop=False, output_type="seq", - callback_on_step_end=cb, - callback_on_step_end_tensor_inputs=keys, ) - self.assertEqual(set(observed), set(keys)) + self.assertEqual(out.sequences.shape, (1, self.canvas_length * 2)) + self.assertIsNone(out.texts) + + def test_generate_text_and_return_tuple(self): + sequences, texts = self.pipe( + messages=self.messages, + gen_length=self.canvas_length, + num_inference_steps=4, + temperature=0.0, + eos_early_stop=False, + output_type="text", + return_dict=False, + ) + self.assertEqual(sequences.shape, (1, self.canvas_length)) + self.assertEqual(len(texts), 1) - def 
test_progress_bar_disable_is_preserved_after_call(self): - pipe = _make_pipeline().to("cpu") - pipe.set_progress_bar_config(disable=True) - before = dict(pipe._progress_bar_config) - pipe( - input_ids=torch.tensor([[5, 6, 7, 8]], dtype=torch.long), - gen_length=8, + def test_callback_receives_advertised_keys(self): + observed: list[str] = [] + + def callback(pipe, step, timestep, callback_kwargs): + observed.extend(sorted(callback_kwargs.keys())) + return {} + + keys = list(self.pipe._callback_tensor_inputs) + self.pipe( + messages=self.messages, + gen_length=self.canvas_length, num_inference_steps=2, + temperature=0.0, eos_early_stop=False, output_type="seq", + callback_on_step_end=callback, + callback_on_step_end_tensor_inputs=keys, ) - self.assertEqual(pipe._progress_bar_config, before) - - -class DiffusionGemmaStaticCacheTest(unittest.TestCase): - """The static-cache path uses the real model internals (encoder prefill + `StaticCache`), so it needs the tiny - checkpoint rather than a stand-in. Skips when the model can't be fetched (e.g. 
offline CI).""" - - def _load_pipeline(self): - try: - from transformers import AutoProcessor, DiffusionGemmaForBlockDiffusion - except ImportError as e: - self.skipTest(f"transformers without DiffusionGemma: {e}") - model_id = "trl-internal-testing/tiny-DiffusionGemmaForBlockDiffusion" - try: - model = DiffusionGemmaForBlockDiffusion.from_pretrained(model_id, dtype=torch.float32).eval() - processor = AutoProcessor.from_pretrained(model_id) - except Exception as e: # noqa: BLE001 - offline / hub errors should skip, not fail - self.skipTest(f"tiny DiffusionGemma checkpoint unavailable: {e}") - pipe = DiffusionGemmaPipeline(model=model, scheduler=BlockRefinementScheduler(), processor=processor) - pipe.set_progress_bar_config(disable=True) - return pipe, model.config.canvas_length + self.assertEqual(set(observed), set(keys)) def test_static_cache_matches_dynamic(self): - pipe, canvas_length = self._load_pipeline() kwargs = { - "messages": [{"role": "user", "content": "Name a color."}], - "gen_length": canvas_length * 2, # two canvases -> exercises the cache extension between blocks + "messages": self.messages, + "gen_length": self.canvas_length * 2, # two canvases -> exercises the cache extension between blocks "num_inference_steps": 4, "temperature": 0.0, "eos_early_stop": False, "output_type": "seq", } - dynamic = pipe(generator=torch.Generator().manual_seed(0), **kwargs).sequences - static = pipe(generator=torch.Generator().manual_seed(0), cache_implementation="static", **kwargs).sequences - self.assertEqual(dynamic.shape, (1, canvas_length * 2)) + dynamic = self.pipe(generator=torch.Generator().manual_seed(0), **kwargs).sequences + static = self.pipe( + generator=torch.Generator().manual_seed(0), cache_implementation="static", **kwargs + ).sequences self.assertTrue(torch.equal(dynamic, static)) From 4ce203e835f8f6bc175b47bf275fcb20a37825d1 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 18 Jun 2026 18:06:13 +0200 Subject: [PATCH 08/17] Support image prompts 
by forwarding pixel_values to the encoder prefill --- .../pipeline_diffusion_gemma.py | 40 ++++++++++++++++--- .../diffusion_gemma/test_diffusion_gemma.py | 21 +++++++++- 2 files changed, 55 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/diffusion_gemma/pipeline_diffusion_gemma.py b/src/diffusers/pipelines/diffusion_gemma/pipeline_diffusion_gemma.py index ee9793d5b1c0..e08f1fee81b3 100644 --- a/src/diffusers/pipelines/diffusion_gemma/pipeline_diffusion_gemma.py +++ b/src/diffusers/pipelines/diffusion_gemma/pipeline_diffusion_gemma.py @@ -105,8 +105,14 @@ def _prepare_input_ids( input_ids: torch.LongTensor | None, attention_mask: torch.LongTensor | None, add_generation_prompt: bool, - ) -> tuple[torch.LongTensor, torch.LongTensor]: - """Convert prompt/messages/input_ids to `(input_ids, attention_mask)` tensors of shape `[batch, seq]`.""" + pixel_values: torch.FloatTensor | None = None, + image_position_ids: torch.LongTensor | None = None, + mm_token_type_ids: torch.LongTensor | None = None, + ) -> tuple[torch.LongTensor, torch.LongTensor, dict[str, torch.Tensor]]: + """Convert prompt/messages/input_ids to `(input_ids, attention_mask, multimodal_inputs)`, where + `multimodal_inputs` holds any image tensors (`pixel_values`, `image_position_ids`, `mm_token_type_ids`) to + forward to the encoder prefill.""" + multimodal_keys = ("pixel_values", "image_position_ids", "mm_token_type_ids") if input_ids is not None: if input_ids.ndim == 1: input_ids = input_ids.unsqueeze(0) @@ -114,7 +120,13 @@ def _prepare_input_ids( attention_mask = torch.ones_like(input_ids, dtype=torch.long) elif attention_mask.ndim == 1: attention_mask = attention_mask.unsqueeze(0) - return input_ids, attention_mask.to(dtype=torch.long) + multimodal_inputs = { + "pixel_values": pixel_values, + "image_position_ids": image_position_ids, + "mm_token_type_ids": mm_token_type_ids, + } + multimodal_inputs = {k: v for k, v in multimodal_inputs.items() if v is not None} + return input_ids, 
attention_mask.to(dtype=torch.long), multimodal_inputs if self.processor is None: raise ValueError("`processor` is required when `input_ids` is not provided.") @@ -136,7 +148,8 @@ def _prepare_input_ids( mask = encoded.get("attention_mask") if mask is None: mask = torch.ones_like(ids, dtype=torch.long) - return ids, mask.to(dtype=torch.long) + multimodal_inputs = {k: encoded[k] for k in multimodal_keys if k in encoded} + return ids, mask.to(dtype=torch.long), multimodal_inputs def check_inputs( self, @@ -176,6 +189,9 @@ def __call__( messages: list[dict[str, str]] | None = None, input_ids: torch.LongTensor | None = None, attention_mask: torch.LongTensor | None = None, + pixel_values: torch.FloatTensor | None = None, + image_position_ids: torch.LongTensor | None = None, + mm_token_type_ids: torch.LongTensor | None = None, add_generation_prompt: bool = True, gen_length: int = 256, num_inference_steps: int = 48, @@ -209,6 +225,14 @@ def __call__( Pre-tokenized prompt IDs. Takes precedence over `prompt` and `messages`. attention_mask (`torch.LongTensor`, *optional*): Per-token mask matching `input_ids`. Only used when `input_ids` is provided. + pixel_values (`torch.FloatTensor`, *optional*): + Image features for multimodal prompts, forwarded to the encoder prefill. When the prompt is built from + `messages` with image content, the processor produces these (and `image_position_ids` / + `mm_token_type_ids`) automatically; pass them explicitly only alongside pre-tokenized `input_ids`. + image_position_ids (`torch.LongTensor`, *optional*): + Patch position coordinates for `pixel_values`. + mm_token_type_ids (`torch.LongTensor`, *optional*): + Per-token modality ids marking image vs text positions for `pixel_values`. add_generation_prompt (`bool`, defaults to `True`): Whether to add the generation prompt when applying the chat template. 
gen_length (`int`, defaults to `256`): @@ -269,17 +293,21 @@ def __call__( callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, ) - prompt_ids, prompt_attention_mask = self._prepare_input_ids( + prompt_ids, prompt_attention_mask, multimodal_inputs = self._prepare_input_ids( prompt=prompt, messages=messages, input_ids=input_ids, attention_mask=attention_mask, add_generation_prompt=add_generation_prompt, + pixel_values=pixel_values, + image_position_ids=image_position_ids, + mm_token_type_ids=mm_token_type_ids, ) device = self._execution_device prompt_ids = prompt_ids.to(device=device) prompt_attention_mask = prompt_attention_mask.to(device=device) + multimodal_inputs = {k: v.to(device=device) for k, v in multimodal_inputs.items()} batch_size, prompt_length = prompt_ids.shape if eos_token_id is None: @@ -319,6 +347,8 @@ def __call__( attention_mask=cur_attention_mask, past_key_values=past_key_values, position_ids=torch.arange(cached_len, cur_len, device=device).unsqueeze(0), + # Image tensors are consumed by the prompt prefill only; later blocks encode text-only canvases. + **(multimodal_inputs if cached_len == 0 else {}), ) # Build the 4D decoder mask once per block (outside any compiled region). 
A static cache spans its full diff --git a/tests/pipelines/diffusion_gemma/test_diffusion_gemma.py b/tests/pipelines/diffusion_gemma/test_diffusion_gemma.py index 7b564eb0db37..eaa121728cab 100644 --- a/tests/pipelines/diffusion_gemma/test_diffusion_gemma.py +++ b/tests/pipelines/diffusion_gemma/test_diffusion_gemma.py @@ -55,12 +55,13 @@ def test_prompt_and_messages_together_raises(self): def test_prepare_input_ids_from_1d_tensor(self): pipe = _make_dummy_pipeline() ids = torch.tensor([1, 2, 3], dtype=torch.long) - result_ids, result_mask = pipe._prepare_input_ids( + result_ids, result_mask, multimodal = pipe._prepare_input_ids( prompt=None, messages=None, input_ids=ids, attention_mask=None, add_generation_prompt=False ) self.assertEqual(result_ids.shape, (1, 3)) self.assertEqual(result_mask.shape, (1, 3)) self.assertTrue((result_mask == 1).all().item()) + self.assertEqual(multimodal, {}) # --- End-to-end generation: the prefill-once path drives the real encoder/decoder, so it needs the tiny model --- @@ -133,6 +134,24 @@ def callback(pipe, step, timestep, callback_kwargs): ) self.assertEqual(set(observed), set(keys)) + def test_generate_with_image(self): + import numpy as np + from PIL import Image + + image = Image.fromarray((np.random.rand(64, 64, 3) * 255).astype("uint8")) + messages = [ + {"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": "What?"}]} + ] + out = self.pipe( + messages=messages, + gen_length=self.canvas_length, + num_inference_steps=2, + temperature=0.0, + eos_early_stop=False, + output_type="seq", + ) + self.assertEqual(out.sequences.shape, (1, self.canvas_length)) + def test_static_cache_matches_dynamic(self): kwargs = { "messages": self.messages, From 9d1df715f2cf369d75f66c2b5cb5deeafda7a46c Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 18 Jun 2026 18:11:55 +0200 Subject: [PATCH 09/17] Restyle docstrings to satisfy doc-builder --- .../pipelines/diffusion_gemma/pipeline_diffusion_gemma.py | 6 
+++--- src/diffusers/schedulers/scheduling_block_refinement.py | 4 ++-- src/diffusers/schedulers/scheduling_discrete_ddim.py | 8 ++++---- src/diffusers/schedulers/scheduling_entropy_bound.py | 6 +++--- 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/diffusers/pipelines/diffusion_gemma/pipeline_diffusion_gemma.py b/src/diffusers/pipelines/diffusion_gemma/pipeline_diffusion_gemma.py index e08f1fee81b3..c6bee0ff9873 100644 --- a/src/diffusers/pipelines/diffusion_gemma/pipeline_diffusion_gemma.py +++ b/src/diffusers/pipelines/diffusion_gemma/pipeline_diffusion_gemma.py @@ -251,9 +251,9 @@ def __call__( Confidence threshold for re-editing already committed tokens. Defaults to the scheduler's value. cache_implementation (`str`, *optional*): Set to `"static"` to prefill the encoder once per block into a persistent `StaticCache` and run the - decoder against it with fixed shapes, instead of re-encoding the full sequence on every step. The - fixed shapes also let you compile the decoder, e.g. - `pipe.model.model.decoder = torch.compile(pipe.model.model.decoder, fullgraph=True)`. + decoder against it with fixed shapes, instead of re-encoding the full sequence on every step. The fixed + shapes also let you compile the decoder, e.g. `pipe.model.model.decoder = + torch.compile(pipe.model.model.decoder, fullgraph=True)`. eos_early_stop (`bool`, defaults to `True`): Whether to stop generating further canvases once every sequence has emitted EOS. eos_token_id (`int`, *optional*): diff --git a/src/diffusers/schedulers/scheduling_block_refinement.py b/src/diffusers/schedulers/scheduling_block_refinement.py index 2a7c73769755..2844fb8f0d29 100644 --- a/src/diffusers/schedulers/scheduling_block_refinement.py +++ b/src/diffusers/schedulers/scheduling_block_refinement.py @@ -209,8 +209,8 @@ def step( Current block token IDs (contains mask tokens for uncommitted positions in the mask-based mode). mask_token_id (`int`, *optional*): Token ID used for masked positions. 
When `None`, the scheduler runs in uniform corruption mode: it - tracks committed positions internally (resetting at `timestep == 0`) and renoises the uncommitted - ones with uniformly random tokens, matching DiffusionGemma's block refinement sampler. + tracks committed positions internally (resetting at `timestep == 0`) and renoises the uncommitted ones + with uniformly random tokens, matching DiffusionGemma's block refinement sampler. temperature (`float`): Sampling temperature. top_p (`float`, *optional*): diff --git a/src/diffusers/schedulers/scheduling_discrete_ddim.py b/src/diffusers/schedulers/scheduling_discrete_ddim.py index 51a890da28a7..22b02b8c2496 100644 --- a/src/diffusers/schedulers/scheduling_discrete_ddim.py +++ b/src/diffusers/schedulers/scheduling_discrete_ddim.py @@ -49,9 +49,9 @@ class DiscreteDDIMScheduler(SchedulerMixin, ConfigMixin): On the linear schedule the survival probability of a clean token at time `t` is `alpha(t) = 1 - t`. One denoising step from time `t` to `s < t` samples every block position from the exact posterior `q(x_s | x_t, x0)`, which for - the uniform kernel decomposes into three routes: jump to the predicted clean token `x0`, stay on the current - token, or jump to a uniformly random token. Unlike masked diffusion, there is no mask token; uncommitted positions - carry random tokens. + the uniform kernel decomposes into three routes: jump to the predicted clean token `x0`, stay on the current token, + or jump to a uniformly random token. Unlike masked diffusion, there is no mask token; uncommitted positions carry + random tokens. 
Args: num_inference_steps (`int`, defaults to 32): @@ -111,7 +111,7 @@ def step( With `a = alpha_t / alpha_s` (survival probability from `s` to `t`) and `b = alpha_s`, the posterior mass of each route is - clean: `b * (1 - a) / K + a * b * 1[x_t = x0]`, stay: `a * (1 - b) / K`, noise: `(1 - a) * (1 - b) / K`, + clean: `b * (1 - a) / K + a * b * 1[x_t = x0]`, stay: `a * (1 - b) / K`, noise: `(1 - a) * (1 - b) / K`, so the last step (`b = 1`) deterministically commits the predicted clean tokens. diff --git a/src/diffusers/schedulers/scheduling_entropy_bound.py b/src/diffusers/schedulers/scheduling_entropy_bound.py index fe54eabb403d..3a6bb50e2f7d 100644 --- a/src/diffusers/schedulers/scheduling_entropy_bound.py +++ b/src/diffusers/schedulers/scheduling_entropy_bound.py @@ -49,9 +49,9 @@ class EntropyBoundScheduler(SchedulerMixin, ConfigMixin): """ Entropy bound scheduler for the uniform corruption process. - At each step the scheduler samples a candidate token per position and accepts the `k` lowest-entropy positions - such that `sum_i^k entropy_i - max(entropy_1, ..., entropy_k) <= entropy_bound`. The left-hand side upper-bounds - the joint mutual information between the accepted tokens, so they are approximately independent. Accepted positions + At each step the scheduler samples a candidate token per position and accepts the `k` lowest-entropy positions such + that `sum_i^k entropy_i - max(entropy_1, ..., entropy_k) <= entropy_bound`. The left-hand side upper-bounds the + joint mutual information between the accepted tokens, so they are approximately independent. Accepted positions keep their sampled token; the rest are renoised with uniformly random tokens (there is no mask token). Proposed in "Beyond Next-Token Prediction" (https://huggingface.co/papers/2505.24857). 
From 18651f55739ce0c8c0016c4876084700041e0d29 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 18 Jun 2026 18:22:06 +0200 Subject: [PATCH 10/17] Sort the new scheduler and pipeline exports --- src/diffusers/__init__.py | 8 ++++---- src/diffusers/schedulers/__init__.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 63ec3643d45e..27978c1f3d8b 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -1247,10 +1247,6 @@ AmusedScheduler, BlockRefinementScheduler, BlockRefinementSchedulerOutput, - DiscreteDDIMScheduler, - DiscreteDDIMSchedulerOutput, - EntropyBoundScheduler, - EntropyBoundSchedulerOutput, CMStochasticIterativeScheduler, CogVideoXDDIMScheduler, CogVideoXDPMScheduler, @@ -1261,11 +1257,15 @@ DDPMScheduler, DDPMWuerstchenScheduler, DEISMultistepScheduler, + DiscreteDDIMScheduler, + DiscreteDDIMSchedulerOutput, DPMSolverMultistepInverseScheduler, DPMSolverMultistepScheduler, DPMSolverSinglestepScheduler, EDMDPMSolverMultistepScheduler, EDMEulerScheduler, + EntropyBoundScheduler, + EntropyBoundSchedulerOutput, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, FlowMapEulerDiscreteScheduler, diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index 440f0b91ded9..b7332480f822 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -150,8 +150,6 @@ from .deprecated import KarrasVeScheduler, ScoreSdeVpScheduler from .scheduling_amused import AmusedScheduler from .scheduling_block_refinement import BlockRefinementScheduler, BlockRefinementSchedulerOutput - from .scheduling_discrete_ddim import DiscreteDDIMScheduler, DiscreteDDIMSchedulerOutput - from .scheduling_entropy_bound import EntropyBoundScheduler, EntropyBoundSchedulerOutput from .scheduling_consistency_decoder import ConsistencyDecoderScheduler from .scheduling_consistency_models import CMStochasticIterativeScheduler 
from .scheduling_ddim import DDIMScheduler @@ -162,12 +160,14 @@ from .scheduling_ddpm_parallel import DDPMParallelScheduler from .scheduling_ddpm_wuerstchen import DDPMWuerstchenScheduler from .scheduling_deis_multistep import DEISMultistepScheduler + from .scheduling_discrete_ddim import DiscreteDDIMScheduler, DiscreteDDIMSchedulerOutput from .scheduling_dpm_cogvideox import CogVideoXDPMScheduler from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler from .scheduling_dpmsolver_multistep_inverse import DPMSolverMultistepInverseScheduler from .scheduling_dpmsolver_singlestep import DPMSolverSinglestepScheduler from .scheduling_edm_dpmsolver_multistep import EDMDPMSolverMultistepScheduler from .scheduling_edm_euler import EDMEulerScheduler + from .scheduling_entropy_bound import EntropyBoundScheduler, EntropyBoundSchedulerOutput from .scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler from .scheduling_euler_discrete import EulerDiscreteScheduler from .scheduling_flow_map_euler_discrete import FlowMapEulerDiscreteScheduler From 1d1efe79e0059cdbe4f5fb8f0d5200ca258c4d5b Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 18 Jun 2026 18:39:12 +0200 Subject: [PATCH 11/17] Let any of the three schedulers drive the pipeline --- .../pipeline_diffusion_gemma.py | 39 ++++++++++++------- .../diffusion_gemma/test_diffusion_gemma.py | 15 +++++++ 2 files changed, 39 insertions(+), 15 deletions(-) diff --git a/src/diffusers/pipelines/diffusion_gemma/pipeline_diffusion_gemma.py b/src/diffusers/pipelines/diffusion_gemma/pipeline_diffusion_gemma.py index c6bee0ff9873..6e4142908c76 100644 --- a/src/diffusers/pipelines/diffusion_gemma/pipeline_diffusion_gemma.py +++ b/src/diffusers/pipelines/diffusion_gemma/pipeline_diffusion_gemma.py @@ -14,6 +14,7 @@ from __future__ import annotations +import inspect from dataclasses import dataclass from typing import Any, Callable @@ -23,7 +24,7 @@ from transformers import DynamicCache, StaticCache from 
...callbacks import MultiPipelineCallbacks, PipelineCallback -from ...schedulers import BlockRefinementScheduler +from ...schedulers import BlockRefinementScheduler, DiscreteDDIMScheduler, EntropyBoundScheduler from ...utils import BaseOutput, logging, replace_example_docstring from ..pipeline_utils import DiffusionPipeline @@ -72,7 +73,7 @@ class DiffusionGemmaPipeline(DiffusionPipeline): """ model: Any - scheduler: BlockRefinementScheduler + scheduler: BlockRefinementScheduler | DiscreteDDIMScheduler | EntropyBoundScheduler processor: Any _callback_tensor_inputs = ["canvas", "logits"] @@ -80,7 +81,7 @@ class DiffusionGemmaPipeline(DiffusionPipeline): def __init__( self, model: Any, - scheduler: BlockRefinementScheduler, + scheduler: BlockRefinementScheduler | DiscreteDDIMScheduler | EntropyBoundScheduler, processor: Any | None = None, ): super().__init__() @@ -315,7 +316,13 @@ def __call__( canvas_length = self.canvas_length num_canvases = (gen_length + canvas_length - 1) // canvas_length - self.scheduler.set_timesteps(num_inference_steps, device=device, block_length=canvas_length) + # Only `BlockRefinementScheduler` takes a per-call `block_length`; the DiscreteDDIM/EntropyBound schedulers do + # not, so we pass scheduler-specific kwargs by signature. + set_timesteps_kwargs = {"device": device} + if "block_length" in inspect.signature(self.scheduler.set_timesteps).parameters: + set_timesteps_kwargs["block_length"] = canvas_length + self.scheduler.set_timesteps(num_inference_steps, **set_timesteps_kwargs) + step_param_names = set(inspect.signature(self.scheduler.step).parameters) self._num_timesteps = num_inference_steps * num_canvases cur_input_ids = prompt_ids @@ -380,18 +387,20 @@ def __call__( ).logits self_conditioning_logits = logits + # Forward only the knobs the chosen scheduler accepts (block refinement takes thresholds/top-k, + # discrete DDIM and entropy bound do not), so any of the schedulers can drive the pipeline. 
+ step_kwargs = { + "mask_token_id": None, + "temperature": temperature, + "top_p": top_p, + "top_k": top_k, + "threshold": threshold, + "editing_threshold": editing_threshold, + "generator": generator, + } + step_kwargs = {k: v for k, v in step_kwargs.items() if k in step_param_names} scheduler_output = self.scheduler.step( - model_output=logits, - timestep=step_idx, - sample=canvas, - mask_token_id=None, - temperature=temperature, - top_p=top_p, - top_k=top_k, - threshold=threshold, - editing_threshold=editing_threshold, - generator=generator, - return_dict=True, + model_output=logits, timestep=step_idx, sample=canvas, return_dict=True, **step_kwargs ) canvas = scheduler_output.prev_sample diff --git a/tests/pipelines/diffusion_gemma/test_diffusion_gemma.py b/tests/pipelines/diffusion_gemma/test_diffusion_gemma.py index eaa121728cab..21601fad7574 100644 --- a/tests/pipelines/diffusion_gemma/test_diffusion_gemma.py +++ b/tests/pipelines/diffusion_gemma/test_diffusion_gemma.py @@ -152,6 +152,21 @@ def test_generate_with_image(self): ) self.assertEqual(out.sequences.shape, (1, self.canvas_length)) + def test_schedulers_are_interchangeable(self): + from diffusers import DiscreteDDIMScheduler, EntropyBoundScheduler + + for scheduler in (DiscreteDDIMScheduler(), EntropyBoundScheduler(entropy_bound=0.1)): + self.pipe.scheduler = scheduler + out = self.pipe( + messages=self.messages, + gen_length=self.canvas_length, + num_inference_steps=4, + temperature=0.0, + eos_early_stop=False, + output_type="seq", + ) + self.assertEqual(out.sequences.shape, (1, self.canvas_length)) + def test_static_cache_matches_dynamic(self): kwargs = { "messages": self.messages, From 8a9ffcfdcd55346d0aaa823234ea927d6b35b6eb Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 18 Jun 2026 18:39:12 +0200 Subject: [PATCH 12/17] Document the schedulers and updated defaults in the pipeline docs --- .../en/api/pipelines/diffusion_gemma.md | 33 ++++++++++++++++--- 1 file changed, 29 insertions(+), 4 
deletions(-) diff --git a/docs/source/en/api/pipelines/diffusion_gemma.md b/docs/source/en/api/pipelines/diffusion_gemma.md index 6061844b25b4..615874f2b7b8 100644 --- a/docs/source/en/api/pipelines/diffusion_gemma.md +++ b/docs/source/en/api/pipelines/diffusion_gemma.md @@ -36,17 +36,42 @@ pipe = DiffusionGemmaPipeline(model=model, scheduler=scheduler, processor=proces output = pipe( prompt="Why is the sky blue?", gen_length=256, - num_inference_steps=32, + num_inference_steps=48, temperature=0.0, ) print(output.texts[0]) ``` +`num_inference_steps` is the number of denoising steps per canvas (48 matches the released checkpoint); fewer steps are +faster but lower quality. For multimodal prompts, pass image content in `messages` and the processor's `pixel_values` +are forwarded to the model automatically. + +## Schedulers + +The scheduler is the sampler that denoises each canvas, and it is interchangeable: swap it to change the sampling +strategy without touching anything else. Three schedulers are available: + +- `BlockRefinementScheduler` (default): commits the most confident tokens each step (above `threshold`, plus an even + per-step quota) and renoises the rest. `editing_threshold` additionally lets it re-edit already committed tokens. +- `DiscreteDDIMScheduler`: samples each position from the exact discrete posterior of the uniform corruption process + (D3PM). It is parameter free, and the final step deterministically commits the predicted tokens. +- `EntropyBoundScheduler`: commits the lowest-entropy positions whose joint entropy stays under `entropy_bound`, so + roughly independent tokens are accepted together. 
+ +```py +from diffusers import DiscreteDDIMScheduler, EntropyBoundScheduler + +pipe.scheduler = DiscreteDDIMScheduler() +# or: pipe.scheduler = EntropyBoundScheduler(entropy_bound=0.1) +output = pipe(prompt="Why is the sky blue?", gen_length=256, num_inference_steps=48) +print(output.texts[0]) +``` + ## Static cache and compilation -By default the pipeline re-encodes the prompt on every denoising step. Pass `cache_implementation="static"` to instead -prefill the encoder once per block into a persistent `StaticCache` and run the decoder against it with fixed shapes. -The fixed shapes let you `torch.compile` the decoder for a further speedup: +The pipeline prefills the encoder once per block into a reusable cache (a `DynamicCache` by default). Pass +`cache_implementation="static"` to use a fixed-shape `StaticCache` instead, whose shapes let you `torch.compile` the +decoder for a further speedup: ```py pipe.model.model.decoder = torch.compile(pipe.model.model.decoder, fullgraph=True) From 73448d9e1bb4e7f110caaa2d9f4297632511340a Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 18 Jun 2026 18:39:12 +0200 Subject: [PATCH 13/17] Sort the scheduler dummy objects --- src/diffusers/utils/dummy_pt_objects.py | 38 ++++++++++++------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 823123988830..66dde8da8728 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -2882,7 +2882,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class DiscreteDDIMScheduler(metaclass=DummyObject): +class CMStochasticIterativeScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -2897,7 +2897,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class DiscreteDDIMSchedulerOutput(metaclass=DummyObject): +class 
CogVideoXDDIMScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -2912,7 +2912,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class EntropyBoundScheduler(metaclass=DummyObject): +class CogVideoXDPMScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -2927,7 +2927,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class EntropyBoundSchedulerOutput(metaclass=DummyObject): +class DDIMInverseScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -2942,7 +2942,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class CMStochasticIterativeScheduler(metaclass=DummyObject): +class DDIMParallelScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -2957,7 +2957,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class CogVideoXDDIMScheduler(metaclass=DummyObject): +class DDIMScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -2972,7 +2972,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class CogVideoXDPMScheduler(metaclass=DummyObject): +class DDPMParallelScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -2987,7 +2987,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class DDIMInverseScheduler(metaclass=DummyObject): +class DDPMScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -3002,7 +3002,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class DDIMParallelScheduler(metaclass=DummyObject): +class DDPMWuerstchenScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -3017,7 +3017,7 @@ def 
from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class DDIMScheduler(metaclass=DummyObject): +class DEISMultistepScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -3032,7 +3032,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class DDPMParallelScheduler(metaclass=DummyObject): +class DiscreteDDIMScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -3047,7 +3047,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class DDPMScheduler(metaclass=DummyObject): +class DiscreteDDIMSchedulerOutput(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -3062,7 +3062,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class DDPMWuerstchenScheduler(metaclass=DummyObject): +class DPMSolverMultistepInverseScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -3077,7 +3077,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class DEISMultistepScheduler(metaclass=DummyObject): +class DPMSolverMultistepScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -3092,7 +3092,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class DPMSolverMultistepInverseScheduler(metaclass=DummyObject): +class DPMSolverSinglestepScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -3107,7 +3107,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class DPMSolverMultistepScheduler(metaclass=DummyObject): +class EDMDPMSolverMultistepScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -3122,7 +3122,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class 
DPMSolverSinglestepScheduler(metaclass=DummyObject): +class EDMEulerScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -3137,7 +3137,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class EDMDPMSolverMultistepScheduler(metaclass=DummyObject): +class EntropyBoundScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -3152,7 +3152,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class EDMEulerScheduler(metaclass=DummyObject): +class EntropyBoundSchedulerOutput(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): From 04dd9b904ac455021ebc11c4185bcdd0bf6af477 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 19 Jun 2026 09:35:27 +0200 Subject: [PATCH 14/17] Set scheduler sampling knobs on the scheduler config, not the pipeline call --- .../en/api/pipelines/diffusion_gemma.md | 9 ++++++ .../pipeline_diffusion_gemma.py | 30 ++++--------------- 2 files changed, 15 insertions(+), 24 deletions(-) diff --git a/docs/source/en/api/pipelines/diffusion_gemma.md b/docs/source/en/api/pipelines/diffusion_gemma.md index 615874f2b7b8..49ee18c43d46 100644 --- a/docs/source/en/api/pipelines/diffusion_gemma.md +++ b/docs/source/en/api/pipelines/diffusion_gemma.md @@ -67,6 +67,15 @@ output = pipe(prompt="Why is the sky blue?", gen_length=256, num_inference_steps print(output.texts[0]) ``` +Scheduler-specific sampling knobs (the block-refinement `threshold`/`top_k`, the entropy bound, ...) are set on the +scheduler config: + +```py +from diffusers import BlockRefinementScheduler + +pipe.scheduler = BlockRefinementScheduler.from_config(pipe.scheduler.config, threshold=0.9) +``` + ## Static cache and compilation The pipeline prefills the encoder once per block into a reusable cache (a `DynamicCache` by default). 
Pass diff --git a/src/diffusers/pipelines/diffusion_gemma/pipeline_diffusion_gemma.py b/src/diffusers/pipelines/diffusion_gemma/pipeline_diffusion_gemma.py index 6e4142908c76..ed106b0c53ff 100644 --- a/src/diffusers/pipelines/diffusion_gemma/pipeline_diffusion_gemma.py +++ b/src/diffusers/pipelines/diffusion_gemma/pipeline_diffusion_gemma.py @@ -197,10 +197,6 @@ def __call__( gen_length: int = 256, num_inference_steps: int = 48, temperature: float = 0.0, - top_p: float | None = None, - top_k: int | None = None, - threshold: float | None = None, - editing_threshold: float | None = None, cache_implementation: str | None = None, eos_early_stop: bool = True, eos_token_id: int | None = None, @@ -241,15 +237,9 @@ def __call__( num_inference_steps (`int`, defaults to `48`): Number of denoising steps per canvas. temperature (`float`, defaults to `0.0`): - Sampling temperature. `0.0` is greedy. - top_p (`float`, *optional*): - Nucleus sampling cutoff. - top_k (`int`, *optional*): - Top-k sampling cutoff. - threshold (`float`, *optional*): - Confidence threshold for committing tokens. Defaults to the scheduler's configured value. - editing_threshold (`float`, *optional*): - Confidence threshold for re-editing already committed tokens. Defaults to the scheduler's value. + Sampling temperature. `0.0` is greedy. Other sampling knobs (e.g. `top_k`, `threshold`) are scheduler + config; set them on the scheduler, e.g. `pipe.scheduler = + BlockRefinementScheduler.from_config(pipe.scheduler.config, top_k=...)`. cache_implementation (`str`, *optional*): Set to `"static"` to prefill the encoder once per block into a persistent `StaticCache` and run the decoder against it with fixed shapes, instead of re-encoding the full sequence on every step. 
The fixed @@ -387,17 +377,9 @@ def __call__( ).logits self_conditioning_logits = logits - # Forward only the knobs the chosen scheduler accepts (block refinement takes thresholds/top-k, - # discrete DDIM and entropy bound do not), so any of the schedulers can drive the pipeline. - step_kwargs = { - "mask_token_id": None, - "temperature": temperature, - "top_p": top_p, - "top_k": top_k, - "threshold": threshold, - "editing_threshold": editing_threshold, - "generator": generator, - } + # Pass only the kwargs the chosen scheduler accepts, so any of the schedulers can drive the pipeline. + # Per-scheduler sampling knobs (thresholds, top-k, ...) live on the scheduler config, not here. + step_kwargs = {"mask_token_id": None, "temperature": temperature, "generator": generator} step_kwargs = {k: v for k, v in step_kwargs.items() if k in step_param_names} scheduler_output = self.scheduler.step( model_output=logits, timestep=step_idx, sample=canvas, return_dict=True, **step_kwargs From 0f0041da13697293efadff9bdda1fbaa7230247b Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 19 Jun 2026 18:14:41 +0200 Subject: [PATCH 15/17] Accept raw prompt/image/messages instead of pre-tokenized model inputs --- .../en/api/pipelines/diffusion_gemma.md | 4 +- .../pipeline_diffusion_gemma.py | 103 ++++++------------ .../diffusion_gemma/test_diffusion_gemma.py | 31 ++---- 3 files changed, 46 insertions(+), 92 deletions(-) diff --git a/docs/source/en/api/pipelines/diffusion_gemma.md b/docs/source/en/api/pipelines/diffusion_gemma.md index 49ee18c43d46..dc2fb671fad2 100644 --- a/docs/source/en/api/pipelines/diffusion_gemma.md +++ b/docs/source/en/api/pipelines/diffusion_gemma.md @@ -43,8 +43,8 @@ print(output.texts[0]) ``` `num_inference_steps` is the number of denoising steps per canvas (48 matches the released checkpoint); fewer steps are -faster but lower quality. 
For multimodal prompts, pass image content in `messages` and the processor's `pixel_values` -are forwarded to the model automatically. +faster but lower quality. For multimodal prompts, pass an `image` alongside the `prompt` (or put the image content in a +raw `messages` conversation), and the processor turns it into the model's image inputs automatically. ## Schedulers diff --git a/src/diffusers/pipelines/diffusion_gemma/pipeline_diffusion_gemma.py b/src/diffusers/pipelines/diffusion_gemma/pipeline_diffusion_gemma.py index ed106b0c53ff..562a14d95baf 100644 --- a/src/diffusers/pipelines/diffusion_gemma/pipeline_diffusion_gemma.py +++ b/src/diffusers/pipelines/diffusion_gemma/pipeline_diffusion_gemma.py @@ -98,45 +98,29 @@ def num_timesteps(self): # --- Prompt encoding --- - def _prepare_input_ids( + def _prepare_inputs( self, *, prompt: str | list[str] | None, - messages: list[dict[str, str]] | None, - input_ids: torch.LongTensor | None, - attention_mask: torch.LongTensor | None, + messages: list[dict] | None, + image: Any | list[Any] | None, add_generation_prompt: bool, - pixel_values: torch.FloatTensor | None = None, - image_position_ids: torch.LongTensor | None = None, - mm_token_type_ids: torch.LongTensor | None = None, ) -> tuple[torch.LongTensor, torch.LongTensor, dict[str, torch.Tensor]]: - """Convert prompt/messages/input_ids to `(input_ids, attention_mask, multimodal_inputs)`, where - `multimodal_inputs` holds any image tensors (`pixel_values`, `image_position_ids`, `mm_token_type_ids`) to - forward to the encoder prefill.""" - multimodal_keys = ("pixel_values", "image_position_ids", "mm_token_type_ids") - if input_ids is not None: - if input_ids.ndim == 1: - input_ids = input_ids.unsqueeze(0) - if attention_mask is None: - attention_mask = torch.ones_like(input_ids, dtype=torch.long) - elif attention_mask.ndim == 1: - attention_mask = attention_mask.unsqueeze(0) - multimodal_inputs = { - "pixel_values": pixel_values, - "image_position_ids": 
image_position_ids, - "mm_token_type_ids": mm_token_type_ids, - } - multimodal_inputs = {k: v for k, v in multimodal_inputs.items() if v is not None} - return input_ids, attention_mask.to(dtype=torch.long), multimodal_inputs + """Tokenize a raw `prompt` (optionally with an `image`) or a raw `messages` conversation into + `(input_ids, attention_mask, multimodal_inputs)`, where `multimodal_inputs` holds the image tensors the + processor produced for the encoder prefill.""" - if self.processor is None: - raise ValueError("`processor` is required when `input_ids` is not provided.") + def build_content(text, img): + if img is None: + return text + return [{"type": "image", "image": img}, {"type": "text", "text": text}] if messages is None: if isinstance(prompt, list): - messages = [[{"role": "user", "content": p}] for p in prompt] + images = image if isinstance(image, list) else [image] * len(prompt) + messages = [[{"role": "user", "content": build_content(p, im)}] for p, im in zip(prompt, images)] else: - messages = [{"role": "user", "content": prompt}] + messages = [{"role": "user", "content": build_content(prompt, image)}] encoded = self.processor.apply_chat_template( messages, @@ -149,31 +133,31 @@ def _prepare_input_ids( mask = encoded.get("attention_mask") if mask is None: mask = torch.ones_like(ids, dtype=torch.long) + multimodal_keys = ("pixel_values", "image_position_ids", "mm_token_type_ids") multimodal_inputs = {k: encoded[k] for k in multimodal_keys if k in encoded} return ids, mask.to(dtype=torch.long), multimodal_inputs def check_inputs( self, prompt: str | list[str] | None, - messages: list[dict[str, str]] | None, - input_ids: torch.LongTensor | None, + messages: list[dict] | None, gen_length: int, num_inference_steps: int, output_type: str, callback_on_step_end_tensor_inputs: list[str] | None, ): - if prompt is None and messages is None and input_ids is None: - raise ValueError("Provide one of `prompt`, `messages`, or `input_ids`.") - if prompt is not 
None and messages is not None: - raise ValueError("Provide either `prompt` or `messages`, not both.") - if (prompt is not None or messages is not None) and input_ids is None and self.processor is None: - raise ValueError("`processor` is required when `input_ids` is not provided.") + if output_type not in {"seq", "text"}: + raise ValueError(f"`output_type` must be 'seq' or 'text', got {output_type!r}.") if gen_length <= 0: raise ValueError(f"`gen_length` must be > 0, got {gen_length}.") if num_inference_steps <= 0: raise ValueError(f"`num_inference_steps` must be > 0, got {num_inference_steps}.") - if output_type not in {"seq", "text"}: - raise ValueError(f"`output_type` must be 'seq' or 'text', got {output_type!r}.") + if prompt is None and messages is None: + raise ValueError("Provide either `prompt` or `messages`.") + if prompt is not None and messages is not None: + raise ValueError("Provide either `prompt` or `messages`, not both.") + if self.processor is None: + raise ValueError("`processor` is required to encode the prompt.") if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): @@ -187,12 +171,8 @@ def check_inputs( def __call__( self, prompt: str | list[str] | None = None, - messages: list[dict[str, str]] | None = None, - input_ids: torch.LongTensor | None = None, - attention_mask: torch.LongTensor | None = None, - pixel_values: torch.FloatTensor | None = None, - image_position_ids: torch.LongTensor | None = None, - mm_token_type_ids: torch.LongTensor | None = None, + messages: list[dict] | None = None, + image: Any | list[Any] | None = None, add_generation_prompt: bool = True, gen_length: int = 256, num_inference_steps: int = 48, @@ -214,22 +194,14 @@ def __call__( Args: prompt (`str` or `List[str]`, *optional*): - Prompt text, wrapped in a chat template and tokenized by the processor. - messages (`List[Dict[str, str]]`, *optional*): - Chat messages to encode (e.g. 
`[{"role": "user", "content": "Hello"}]`). Takes precedence over - `prompt`. Requires a processor with `apply_chat_template`. - input_ids (`torch.LongTensor`, *optional*): - Pre-tokenized prompt IDs. Takes precedence over `prompt` and `messages`. - attention_mask (`torch.LongTensor`, *optional*): - Per-token mask matching `input_ids`. Only used when `input_ids` is provided. - pixel_values (`torch.FloatTensor`, *optional*): - Image features for multimodal prompts, forwarded to the encoder prefill. When the prompt is built from - `messages` with image content, the processor produces these (and `image_position_ids` / - `mm_token_type_ids`) automatically; pass them explicitly only alongside pre-tokenized `input_ids`. - image_position_ids (`torch.LongTensor`, *optional*): - Patch position coordinates for `pixel_values`. - mm_token_type_ids (`torch.LongTensor`, *optional*): - Per-token modality ids marking image vs text positions for `pixel_values`. + Prompt text, wrapped in a chat template and tokenized by the processor. Provide either this or + `messages`. + messages (`List[Dict]`, *optional*): + A raw chat conversation to encode, e.g. `[{"role": "user", "content": "Hello"}]` or a multi-turn / + multimodal conversation. Use this instead of `prompt` for anything beyond a single user turn. + image (`PIL.Image.Image` or `List`, *optional*): + Image(s) to pair with `prompt` for multimodal generation; the processor turns them into the model's + image inputs. For richer layouts, put the image content directly in `messages`. add_generation_prompt (`bool`, defaults to `True`): Whether to add the generation prompt when applying the chat template. 
gen_length (`int`, defaults to `256`): @@ -277,22 +249,17 @@ def __call__( self.check_inputs( prompt=prompt, messages=messages, - input_ids=input_ids, gen_length=gen_length, num_inference_steps=num_inference_steps, output_type=output_type, callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, ) - prompt_ids, prompt_attention_mask, multimodal_inputs = self._prepare_input_ids( + prompt_ids, prompt_attention_mask, multimodal_inputs = self._prepare_inputs( prompt=prompt, messages=messages, - input_ids=input_ids, - attention_mask=attention_mask, + image=image, add_generation_prompt=add_generation_prompt, - pixel_values=pixel_values, - image_position_ids=image_position_ids, - mm_token_type_ids=mm_token_type_ids, ) device = self._execution_device diff --git a/tests/pipelines/diffusion_gemma/test_diffusion_gemma.py b/tests/pipelines/diffusion_gemma/test_diffusion_gemma.py index 21601fad7574..76e322c923a6 100644 --- a/tests/pipelines/diffusion_gemma/test_diffusion_gemma.py +++ b/tests/pipelines/diffusion_gemma/test_diffusion_gemma.py @@ -45,24 +45,13 @@ def test_no_inputs_raises(self): def test_output_type_invalid_raises(self): pipe = _make_dummy_pipeline() with self.assertRaises(ValueError): - pipe(input_ids=torch.tensor([[1, 2, 3]], dtype=torch.long), gen_length=8, output_type="invalid") + pipe(prompt="hi", gen_length=8, output_type="invalid") def test_prompt_and_messages_together_raises(self): pipe = _make_dummy_pipeline() with self.assertRaises(ValueError): pipe(prompt="hi", messages=[{"role": "user", "content": "hi"}], gen_length=8, output_type="seq") - def test_prepare_input_ids_from_1d_tensor(self): - pipe = _make_dummy_pipeline() - ids = torch.tensor([1, 2, 3], dtype=torch.long) - result_ids, result_mask, multimodal = pipe._prepare_input_ids( - prompt=None, messages=None, input_ids=ids, attention_mask=None, add_generation_prompt=False - ) - self.assertEqual(result_ids.shape, (1, 3)) - self.assertEqual(result_mask.shape, (1, 3)) - 
self.assertTrue((result_mask == 1).all().item()) - self.assertEqual(multimodal, {}) - # --- End-to-end generation: the prefill-once path drives the real encoder/decoder, so it needs the tiny model --- @@ -87,11 +76,11 @@ def _load_pipeline(test): class DiffusionGemmaPipelineTest(unittest.TestCase): def setUp(self): self.pipe, self.canvas_length = _load_pipeline(self) - self.messages = [{"role": "user", "content": "Name a color."}] + self.prompt = "Name a color." def test_generate_seq_shape(self): out = self.pipe( - messages=self.messages, + prompt=self.prompt, gen_length=self.canvas_length * 2, num_inference_steps=4, temperature=0.0, @@ -103,7 +92,7 @@ def test_generate_seq_shape(self): def test_generate_text_and_return_tuple(self): sequences, texts = self.pipe( - messages=self.messages, + prompt=self.prompt, gen_length=self.canvas_length, num_inference_steps=4, temperature=0.0, @@ -123,7 +112,7 @@ def callback(pipe, step, timestep, callback_kwargs): keys = list(self.pipe._callback_tensor_inputs) self.pipe( - messages=self.messages, + prompt=self.prompt, gen_length=self.canvas_length, num_inference_steps=2, temperature=0.0, @@ -139,11 +128,9 @@ def test_generate_with_image(self): from PIL import Image image = Image.fromarray((np.random.rand(64, 64, 3) * 255).astype("uint8")) - messages = [ - {"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": "What?"}]} - ] out = self.pipe( - messages=messages, + prompt="What?", + image=image, gen_length=self.canvas_length, num_inference_steps=2, temperature=0.0, @@ -158,7 +145,7 @@ def test_schedulers_are_interchangeable(self): for scheduler in (DiscreteDDIMScheduler(), EntropyBoundScheduler(entropy_bound=0.1)): self.pipe.scheduler = scheduler out = self.pipe( - messages=self.messages, + prompt=self.prompt, gen_length=self.canvas_length, num_inference_steps=4, temperature=0.0, @@ -169,7 +156,7 @@ def test_schedulers_are_interchangeable(self): def test_static_cache_matches_dynamic(self): kwargs 
= { - "messages": self.messages, + "prompt": self.prompt, "gen_length": self.canvas_length * 2, # two canvases -> exercises the cache extension between blocks "num_inference_steps": 4, "temperature": 0.0, From 3b44324f8972ba74f573d7135172fbe1a577f0d7 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sat, 20 Jun 2026 08:06:23 +0000 Subject: [PATCH 16/17] Add leave-one-out predictor-corrector to DiscreteDDIM scheduler Adds optional Gibbs corrector sweeps after each predictor step for uniform diffusion, recovering the LOO denoiser in closed form so it works on the released checkpoint with no retraining. Co-Authored-By: Claude Opus 4.8 --- .../en/api/pipelines/diffusion_gemma.md | 12 ++ .../pipeline_diffusion_gemma.py | 18 +++ .../schedulers/scheduling_discrete_ddim.py | 139 +++++++++++++++++- .../diffusion_gemma/test_diffusion_gemma.py | 14 ++ .../test_scheduler_discrete_ddim.py | 39 +++++ 5 files changed, 221 insertions(+), 1 deletion(-) diff --git a/docs/source/en/api/pipelines/diffusion_gemma.md b/docs/source/en/api/pipelines/diffusion_gemma.md index dc2fb671fad2..ddd5d24cea7f 100644 --- a/docs/source/en/api/pipelines/diffusion_gemma.md +++ b/docs/source/en/api/pipelines/diffusion_gemma.md @@ -76,6 +76,18 @@ from diffusers import BlockRefinementScheduler pipe.scheduler = BlockRefinementScheduler.from_config(pipe.scheduler.config, threshold=0.9) ``` +### Predictor-corrector sampling + +`DiscreteDDIMScheduler` supports the leave-one-out predictor-corrector of [Reparameterizing Uniform Diffusion Models](https://huggingface.co/papers/2605.22765). After each predictor step the pipeline runs `corrector_steps` Gibbs sweeps that resample the least-confident positions from the one-coordinate conditional of the noisy marginal, which leaves that marginal invariant and improves generation at no extra training cost. 
It works directly on the released checkpoint: for uniform diffusion the denoiser and the leave-one-out posterior are interchangeable in closed form, so the corrector recovers the leave-one-out quantities it needs without any retraining. + +```py +from diffusers import DiscreteDDIMScheduler + +pipe.scheduler = DiscreteDDIMScheduler(corrector_steps=2, corrector_k=12) +output = pipe(prompt="Why is the sky blue?", gen_length=256, num_inference_steps=48) +print(output.texts[0]) +``` + ## Static cache and compilation The pipeline prefills the encoder once per block into a reusable cache (a `DynamicCache` by default). Pass diff --git a/src/diffusers/pipelines/diffusion_gemma/pipeline_diffusion_gemma.py b/src/diffusers/pipelines/diffusion_gemma/pipeline_diffusion_gemma.py index 562a14d95baf..60f8cedfa539 100644 --- a/src/diffusers/pipelines/diffusion_gemma/pipeline_diffusion_gemma.py +++ b/src/diffusers/pipelines/diffusion_gemma/pipeline_diffusion_gemma.py @@ -353,6 +353,24 @@ def __call__( ) canvas = scheduler_output.prev_sample + # Predictor-corrector (https://huggingface.co/papers/2605.22765): each Gibbs sweep needs fresh logits on + # the updated canvas, so the model is recomputed here. Skipped on the final step. 
+ corrector_steps = ( + self.scheduler.config.corrector_steps if isinstance(self.scheduler, DiscreteDDIMScheduler) else 0 + ) + if corrector_steps and step_idx + 1 < num_inference_steps: + for _ in range(corrector_steps): + corrector_logits = self.model( + decoder_input_ids=canvas, + past_key_values=past_key_values, + self_conditioning_logits=self_conditioning_logits, + decoder_attention_mask=mask_mapping, + decoder_position_ids=decoder_position_ids, + ).logits + canvas = self.scheduler.step_correct( + model_output=corrector_logits, timestep=step_idx, sample=canvas, generator=generator + ).prev_sample + if callback_on_step_end is not None: callback_kwargs = {k: locals()[k] for k in callback_on_step_end_tensor_inputs} callback_outputs = callback_on_step_end(self, global_step, step_idx, callback_kwargs) diff --git a/src/diffusers/schedulers/scheduling_discrete_ddim.py b/src/diffusers/schedulers/scheduling_discrete_ddim.py index 22b02b8c2496..daf0abcf47f0 100644 --- a/src/diffusers/schedulers/scheduling_discrete_ddim.py +++ b/src/diffusers/schedulers/scheduling_discrete_ddim.py @@ -14,6 +14,7 @@ from __future__ import annotations +import math from dataclasses import dataclass import torch @@ -53,15 +54,37 @@ class DiscreteDDIMScheduler(SchedulerMixin, ConfigMixin): or jump to a uniformly random token. Unlike masked diffusion, there is no mask token; uncommitted positions carry random tokens. + An optional predictor-corrector mode follows "Reparameterizing Uniform Diffusion Models" via the leave-one-out + (LOO) denoiser (https://huggingface.co/papers/2605.22765). When `corrector_steps > 0`, the pipeline runs that many + Gibbs corrector sweeps after each predictor step (see [`~DiscreteDDIMScheduler.step_correct`]), resampling the + least-confident positions from the one-coordinate conditional `Cat(alpha_s * x0_loo + (1 - alpha_s) / K)` while + holding the rest fixed, which leaves the marginal `p_s` invariant and improves generation at no training cost. 
+ Args: num_inference_steps (`int`, defaults to 32): The number of denoising steps, defining the linear time grid the posterior is evaluated on. + corrector_steps (`int`, defaults to 0): + Number of Gibbs corrector sweeps run after each predictor step. `0` recovers plain ancestral DDIM sampling. + corrector_k (`int`, defaults to 1): + Number of positions resampled per corrector sweep. + corrector_selection (`str`, defaults to `"lowest_log_margin"`): + How the resampled positions are chosen: `"lowest_log_margin"`, `"lowest_maxprob"`, `"lowest_current_prob"`, + or `"random"`. + corrector_selection_tau (`float`, defaults to 1.0): + Temperature of the Gumbel-top-k position selection (lower is greedier). """ order = 1 @register_to_config - def __init__(self, num_inference_steps: int = 32): + def __init__( + self, + num_inference_steps: int = 32, + corrector_steps: int = 0, + corrector_k: int = 1, + corrector_selection: str = "lowest_log_margin", + corrector_selection_tau: float = 1.0, + ): self.num_inference_steps = num_inference_steps self.timesteps = torch.arange(num_inference_steps, dtype=torch.long) @@ -95,6 +118,26 @@ def _sample_from_logits( token_prob = torch.gather(probs, -1, token) return token.view(*logits.shape[:-1]), token_prob.view(*logits.shape[:-1]) + def _alpha(self, step_index: int) -> float: + """Survival probability `alpha = 1 - t` of a clean token at the time grid point `step_index`.""" + return step_index / self.num_inference_steps + + @staticmethod + def _to_loo_logits(logits: torch.Tensor, tokens: torch.LongTensor, alpha: float) -> torch.Tensor: + """ + Convert plain-denoiser logits to the leave-one-out posterior for the uniform kernel. + + Subtracts `log(1 + K * alpha / (1 - alpha))` from the observed token's logit (eq. 13 of + https://huggingface.co/papers/2605.22765); renormalization happens in the following softmax. 
+ """ + if alpha <= 0.0 or alpha >= 1.0: + return logits + delta = math.log1p(logits.shape[-1] * alpha / (1.0 - alpha)) + shifted = logits.clone() + src = torch.full((*tokens.shape, 1), -delta, dtype=shifted.dtype, device=shifted.device) + shifted.scatter_add_(-1, tokens.unsqueeze(-1), src) + return shifted + def step( self, model_output: torch.Tensor, @@ -169,5 +212,99 @@ def step( sampled_probs=sampled_probs, ) + def _select_positions( + self, sample: torch.LongTensor, cond_log_probs: torch.Tensor, generator: torch.Generator | None + ) -> torch.LongTensor: + """Pick `corrector_k` positions per row to resample, least-confident first (Gumbel-top-k without replacement).""" + selection = self.config.corrector_selection + batch_size, seq_len = sample.shape + k_eff = min(max(1, int(self.config.corrector_k)), seq_len) + + if selection == "random": + scores = torch.rand(batch_size, seq_len, device=sample.device, generator=generator) + return torch.topk(scores, k=k_eff, dim=-1).indices + + if selection == "lowest_maxprob": + confidence = -cond_log_probs.max(dim=-1).values + elif selection == "lowest_current_prob": + confidence = -torch.gather(cond_log_probs, -1, sample.unsqueeze(-1)).squeeze(-1) + elif selection == "lowest_log_margin": + log_current = torch.gather(cond_log_probs, -1, sample.unsqueeze(-1)).squeeze(-1) + alt = cond_log_probs.clone().scatter_(-1, sample.unsqueeze(-1), float("-inf")) + confidence = -(log_current - alt.max(dim=-1).values) + else: + raise ValueError(f"Unknown `corrector_selection`: {selection!r}.") + + keys = confidence / float(self.config.corrector_selection_tau) + u = torch.rand(keys.shape, device=keys.device, generator=generator).clamp_(1e-12, 1.0 - 1e-12) + keys = keys + (-torch.log(-torch.log(u))) + return torch.topk(keys, k=k_eff, dim=-1).indices + + def step_correct( + self, + model_output: torch.Tensor, + timestep: int | torch.Tensor, + sample: torch.LongTensor, + *, + generator: torch.Generator | None = None, + return_dict: bool = 
True, + ) -> DiscreteDDIMSchedulerOutput | tuple[torch.LongTensor, torch.LongTensor, torch.Tensor]: + """ + Run one Gibbs corrector sweep at the post-predictor time `s`, following the leave-one-out predictor-corrector + of https://huggingface.co/papers/2605.22765. + + The model logits (recomputed on the current `sample`) are converted to the LOO denoiser, the one-coordinate + conditional `p_s(x^l | x^{-l}) = Cat(alpha_s * x0_loo + (1 - alpha_s) / K)` is formed, the least-confident + `corrector_k` positions are selected, and those positions are resampled while the rest are held fixed. The + sweep preserves `p_s`, so it refines the sample without changing its marginal and needs no extra training. + + Args: + model_output (`torch.Tensor` of shape `(batch_size, block_length, vocab_size)`): + Raw logits from the model recomputed on the current (post-predictor) `sample`. + timestep (`int` or `torch.Tensor`): + The predictor step index just completed; the corrector runs at the following grid point `s`. + sample (`torch.LongTensor` of shape `(batch_size, block_length)`): + Current block token IDs to refine. + generator (`torch.Generator`, *optional*): + RNG for sampling. + return_dict (`bool`): + Whether to return a [`DiscreteDDIMSchedulerOutput`] or a plain tuple. + """ + if isinstance(timestep, torch.Tensor): + step_index = int(timestep.item()) + else: + step_index = int(timestep) + + # The corrector acts at the cleaner time `s` reached by the predictor. + alpha_s = self._alpha(step_index + 1) + vocab_size = model_output.shape[-1] + + # Match the reference corrector, which forms the conditional in float64 (the LOO correction reaches ~log(K)). 
+ loo_logits = self._to_loo_logits(model_output.double(), sample, alpha_s) + loo_log_probs = torch.log_softmax(loo_logits, dim=-1) + log_uniform = math.log1p(-alpha_s) - math.log(vocab_size) + cond_log_probs = torch.logaddexp( + math.log(alpha_s) + loo_log_probs, torch.full_like(loo_log_probs, log_uniform) + ) + + positions = self._select_positions(sample, cond_log_probs, generator) + rows = torch.arange(sample.shape[0], device=sample.device).unsqueeze(-1).expand_as(positions) + chosen_probs = cond_log_probs[rows, positions].exp() + resampled = torch.multinomial( + chosen_probs.reshape(-1, vocab_size), num_samples=1, generator=generator + ).view_as(positions) + + prev_sample = sample.clone() + prev_sample[rows, positions] = resampled + sampled_probs = torch.gather(chosen_probs, -1, resampled.unsqueeze(-1)).squeeze(-1) + + if not return_dict: + return prev_sample, resampled, sampled_probs + return DiscreteDDIMSchedulerOutput( + prev_sample=prev_sample, + sampled_tokens=resampled, + sampled_probs=sampled_probs, + ) + __all__ = ["DiscreteDDIMScheduler", "DiscreteDDIMSchedulerOutput"] diff --git a/tests/pipelines/diffusion_gemma/test_diffusion_gemma.py b/tests/pipelines/diffusion_gemma/test_diffusion_gemma.py index 76e322c923a6..5bf10f4c4e51 100644 --- a/tests/pipelines/diffusion_gemma/test_diffusion_gemma.py +++ b/tests/pipelines/diffusion_gemma/test_diffusion_gemma.py @@ -154,6 +154,20 @@ def test_schedulers_are_interchangeable(self): ) self.assertEqual(out.sequences.shape, (1, self.canvas_length)) + def test_predictor_corrector_sampling(self): + from diffusers import DiscreteDDIMScheduler + + self.pipe.scheduler = DiscreteDDIMScheduler(corrector_steps=2, corrector_k=2) + out = self.pipe( + prompt=self.prompt, + gen_length=self.canvas_length, + num_inference_steps=4, + temperature=0.0, + eos_early_stop=False, + output_type="seq", + ) + self.assertEqual(out.sequences.shape, (1, self.canvas_length)) + def test_static_cache_matches_dynamic(self): kwargs = { "prompt": 
self.prompt, diff --git a/tests/schedulers/test_scheduler_discrete_ddim.py b/tests/schedulers/test_scheduler_discrete_ddim.py index 8073fbb227b3..ce5ed47ab13b 100644 --- a/tests/schedulers/test_scheduler_discrete_ddim.py +++ b/tests/schedulers/test_scheduler_discrete_ddim.py @@ -65,3 +65,42 @@ def test_return_tuple(self): out = scheduler.step(logits, timestep=2, sample=sample, return_dict=False) self.assertIsInstance(out, tuple) self.assertEqual(len(out), 3) + + def test_to_loo_only_shifts_observed_token(self): + # The denoiser->LOO conversion moves only the observed token's logit at each position (eq. 13). + scheduler = self.get_scheduler() + sample = torch.randint(0, 100, (2, 16)) + logits = torch.randn(2, 16, 100) + loo = scheduler._to_loo_logits(logits, sample, alpha=0.4) + diff = loo - logits + moved = diff.abs() > 0 + self.assertTrue(torch.equal(moved.sum(dim=-1), torch.ones(2, 16, dtype=torch.long))) + + def test_step_correct_output_shapes(self): + scheduler = self.get_scheduler(corrector_steps=1, corrector_k=4) + scheduler.set_timesteps(8) + sample = torch.randint(0, 100, (3, 16)) + logits = torch.randn(3, 16, 100) + out = scheduler.step_correct(logits, timestep=2, sample=sample) + self.assertEqual(out.prev_sample.shape, sample.shape) + self.assertEqual(out.prev_sample.dtype, sample.dtype) + + def test_step_correct_resamples_at_most_k(self): + # A corrector sweep holds all but `corrector_k` positions per row fixed. 
+ k = 3 + scheduler = self.get_scheduler(corrector_steps=1, corrector_k=k) + scheduler.set_timesteps(8) + sample = torch.randint(0, 100, (4, 16)) + logits = torch.randn(4, 16, 100) + out = scheduler.step_correct(logits, timestep=2, sample=sample) + changed = (out.prev_sample != sample).sum(dim=-1) + self.assertTrue(torch.all(changed <= k)) + + def test_step_correct_return_tuple(self): + scheduler = self.get_scheduler(corrector_steps=1) + scheduler.set_timesteps(8) + sample = torch.randint(0, 100, (1, 16)) + logits = torch.randn(1, 16, 100) + out = scheduler.step_correct(logits, timestep=2, sample=sample, return_dict=False) + self.assertIsInstance(out, tuple) + self.assertEqual(len(out), 3) From 2b4f9bf1b4772672d9c42fcf4e6eddf84ce91e46 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sat, 20 Jun 2026 10:03:58 +0000 Subject: [PATCH 17/17] Forward PEFT adapter API on the DiffusionGemma pipeline The denoiser is a Transformers model, so adapters (LoRA, DoRA, ...) load through its native PEFT integration rather than the diffusers LoRA loader. Also dispatch the predictor-corrector by scheduler capability instead of class. Co-Authored-By: Claude Opus 4.8 --- .../en/api/pipelines/diffusion_gemma.md | 15 ++++++ .../pipeline_diffusion_gemma.py | 47 +++++++++++++++++-- .../diffusion_gemma/test_diffusion_gemma.py | 31 ++++++++++++ 3 files changed, 88 insertions(+), 5 deletions(-) diff --git a/docs/source/en/api/pipelines/diffusion_gemma.md b/docs/source/en/api/pipelines/diffusion_gemma.md index ddd5d24cea7f..e2c6b0c4178a 100644 --- a/docs/source/en/api/pipelines/diffusion_gemma.md +++ b/docs/source/en/api/pipelines/diffusion_gemma.md @@ -88,6 +88,21 @@ output = pipe(prompt="Why is the sky blue?", gen_length=256, num_inference_steps print(output.texts[0]) ``` +## PEFT adapters + +The denoiser is a 🤗 Transformers model, so adapters are loaded through its native [PEFT](https://huggingface.co/docs/peft) integration rather than the diffusers `load_lora_weights` API. 
Because that integration is adapter-type-agnostic, the same calls load LoRA, DoRA, or any other PEFT adapter (e.g. the output of TRL's `SFTTrainer`). The pipeline forwards the PEFT API so you can manage adapters from the pipeline directly: + +```py +pipe.load_adapter("path/to/adapter", adapter_name="sft") # LoRA, DoRA, ... +pipe.set_adapter("sft") +output = pipe(prompt="Why is the sky blue?", gen_length=256) + +pipe.disable_adapters() # run the base model +pipe.delete_adapter("sft") +``` + +Adapters stay active and unmerged: DiffusionGemma ties the encoder and decoder base weights, so fusing an adapter into them would corrupt both branches. + ## Static cache and compilation The pipeline prefills the encoder once per block into a reusable cache (a `DynamicCache` by default). Pass diff --git a/src/diffusers/pipelines/diffusion_gemma/pipeline_diffusion_gemma.py b/src/diffusers/pipelines/diffusion_gemma/pipeline_diffusion_gemma.py index 60f8cedfa539..af0319147854 100644 --- a/src/diffusers/pipelines/diffusion_gemma/pipeline_diffusion_gemma.py +++ b/src/diffusers/pipelines/diffusion_gemma/pipeline_diffusion_gemma.py @@ -96,6 +96,44 @@ def __init__( def num_timesteps(self): return self._num_timesteps + # --- PEFT adapters --- + # + # The denoiser is a 🤗 Transformers model (a `PeftAdapterMixin`), not a diffusers `ModelMixin`, so the diffusers + # `LoraLoaderMixin` (which targets diffusers components and fuses kohya-format LoRA) does not apply. We instead + # forward the model's native, adapter-type-agnostic PEFT API, so LoRA, DoRA, and other PEFT adapters all load the + # same way. Adapters stay active and unmerged: DiffusionGemma ties the encoder and decoder base weights, so fusing + # would corrupt them. + + def load_adapter(self, *args, **kwargs): + """ + Load a PEFT adapter (LoRA, DoRA, ...) into the underlying model. 
+ + Forwards to [`~transformers.integrations.PeftAdapterMixin.load_adapter`]; the first argument is the path or Hub + id of the adapter (a directory with `adapter_config.json` and the adapter weights, e.g. the output of + [`SFTTrainer`]). + """ + return self.model.load_adapter(*args, **kwargs) + + def set_adapter(self, adapter_name: str | list[str]): + """Activate one or more loaded adapters by name.""" + self.model.set_adapter(adapter_name) + + def enable_adapters(self): + """Enable the attached adapters.""" + self.model.enable_adapters() + + def disable_adapters(self): + """Disable all adapters and run the base model.""" + self.model.disable_adapters() + + def delete_adapter(self, adapter_name: str | list[str]): + """Delete one or more loaded adapters.""" + self.model.delete_adapter(adapter_name) + + def active_adapters(self) -> list[str]: + """Names of the currently active adapters.""" + return self.model.active_adapters() if getattr(self.model, "peft_config", None) else [] + # --- Prompt encoding --- def _prepare_inputs( @@ -353,11 +391,10 @@ def __call__( ) canvas = scheduler_output.prev_sample - # Predictor-corrector (https://huggingface.co/papers/2605.22765): each Gibbs sweep needs fresh logits on - # the updated canvas, so the model is recomputed here. Skipped on the final step. - corrector_steps = ( - self.scheduler.config.corrector_steps if isinstance(self.scheduler, DiscreteDDIMScheduler) else 0 - ) + # Predictor-corrector (https://huggingface.co/papers/2605.22765): a scheduler may expose extra corrector + # sweeps via a `corrector_steps` config and a `step_correct` method. Each sweep needs fresh logits on the + # updated canvas, so the model is recomputed here. Skipped on the final step. 
+ corrector_steps = getattr(self.scheduler.config, "corrector_steps", 0) if corrector_steps and step_idx + 1 < num_inference_steps: for _ in range(corrector_steps): corrector_logits = self.model( diff --git a/tests/pipelines/diffusion_gemma/test_diffusion_gemma.py b/tests/pipelines/diffusion_gemma/test_diffusion_gemma.py index 5bf10f4c4e51..4110063a70eb 100644 --- a/tests/pipelines/diffusion_gemma/test_diffusion_gemma.py +++ b/tests/pipelines/diffusion_gemma/test_diffusion_gemma.py @@ -3,6 +3,7 @@ import torch from diffusers import BlockRefinementScheduler, DiffusionGemmaPipeline +from diffusers.utils.testing_utils import require_peft_backend, require_peft_version_greater # --- Lightweight stand-in for input-validation tests that never reach the model --- @@ -168,6 +169,36 @@ def test_predictor_corrector_sampling(self): ) self.assertEqual(out.sequences.shape, (1, self.canvas_length)) + @require_peft_backend + @require_peft_version_greater("0.18.9") + def test_peft_adapter_api(self): + from peft import LoraConfig + + self.assertEqual(self.pipe.active_adapters(), []) + + # The forwarded API is adapter-type-agnostic; LoRA stands in for any PEFT adapter (DoRA, IA3, ...). + self.pipe.model.add_adapter( + LoraConfig(r=4, lora_alpha=8, lora_dropout=0.0, target_modules="all-linear"), + adapter_name="test", + ) + self.pipe.set_adapter("test") + self.assertIn("test", self.pipe.active_adapters()) + + out = self.pipe( + prompt=self.prompt, + gen_length=self.canvas_length, + num_inference_steps=2, + temperature=0.0, + eos_early_stop=False, + output_type="seq", + ) + self.assertEqual(out.sequences.shape, (1, self.canvas_length)) + + self.pipe.disable_adapters() + self.pipe.enable_adapters() + self.pipe.delete_adapter("test") + self.assertEqual(self.pipe.active_adapters(), []) + def test_static_cache_matches_dynamic(self): kwargs = { "prompt": self.prompt,