From 135542b0f372ccbd529e36c4753fef6043735f44 Mon Sep 17 00:00:00 2001 From: Nina Shvetsova Date: Tue, 16 Jun 2026 08:40:30 +0000 Subject: [PATCH] Refactor WAN loading logic: granular flags & boilerplate cleanup. Replace 'vae_only' and 'transformer_only' flags in WAN pipelines and checkpointers with granular 'load_*' boolean flags. Centralize repetitive loading and instantiation logic (from_pretrained, from_checkpoint, etc.) into the base classes. --- .../checkpointing/wan_checkpointer.py | 74 ++++- .../checkpointing/wan_checkpointer_2_1.py | 24 +- .../checkpointing/wan_checkpointer_2_2.py | 29 +- .../checkpointing/wan_checkpointer_i2v_2p1.py | 24 +- .../checkpointing/wan_checkpointer_i2v_2p2.py | 29 +- .../wan_vace_checkpointer_2_1.py | 24 +- .../pipelines/wan/wan_pipeline.py | 255 +++++++++++++----- .../pipelines/wan/wan_pipeline_2_1.py | 81 +++--- .../pipelines/wan/wan_pipeline_2_2.py | 73 +++-- .../pipelines/wan/wan_pipeline_animate.py | 35 +-- .../pipelines/wan/wan_pipeline_i2v_2p1.py | 54 ++-- .../pipelines/wan/wan_pipeline_i2v_2p2.py | 73 +++-- .../pipelines/wan/wan_vace_pipeline_2_1.py | 60 ++--- .../tests/wan/wan_checkpointer_test.py | 217 +++++++++++---- .../tests/wan/wan_vace_pipeline_test.py | 218 +++++++++++++++ 15 files changed, 826 insertions(+), 444 deletions(-) diff --git a/src/maxdiffusion/checkpointing/wan_checkpointer.py b/src/maxdiffusion/checkpointing/wan_checkpointer.py index e02072f9c..6ebb5bab7 100644 --- a/src/maxdiffusion/checkpointing/wan_checkpointer.py +++ b/src/maxdiffusion/checkpointing/wan_checkpointer.py @@ -16,7 +16,7 @@ from abc import ABC, abstractmethod import json -from typing import Optional, Tuple +from typing import Optional, Tuple, Generic, TypeVar, Type import jax from flax import nnx from maxdiffusion.checkpointing.checkpointing_utils import ( @@ -24,10 +24,7 @@ create_orbax_checkpoint_manager, get_cpu_mesh_and_sharding, ) -from ..pipelines.wan.wan_pipeline_2_1 import WanPipeline2_1 -from ..pipelines.wan.wan_pipeline_2_2 import WanPipeline2_2 -from ..pipelines.wan.wan_pipeline_i2v_2p1 import WanPipelineI2V_2_1 -from ..pipelines.wan.wan_pipeline_i2v_2p2 import WanPipelineI2V_2_2 +from ..pipelines.wan.wan_pipeline import WanPipeline from .. import max_logging, max_utils import orbax.checkpoint as ocp @@ -35,7 +32,11 @@ WAN_CHECKPOINT = "WAN_CHECKPOINT" -class WanCheckpointer(ABC): +T = TypeVar("T", bound=WanPipeline) + + +class WanCheckpointer(Generic[T], ABC): + pipeline_class: Optional[Type[T]] = None def __init__(self, config, checkpoint_type: str = WAN_CHECKPOINT): self.config = config @@ -176,16 +177,61 @@ def _pretrained_save_items(pipeline, pretrained_state_sources, pretrained_config def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]: raise NotImplementedError - @abstractmethod - def load_diffusers_checkpoint(self): - raise NotImplementedError + def load_diffusers_checkpoint( + self, + vae_only=False, + load_vae=None, + load_text_encoder=None, + load_transformer=None, + load_scheduler=None, + ) -> T: + pipeline = self.pipeline_class.from_pretrained( + self.config, + vae_only=vae_only, + load_vae=load_vae, + load_text_encoder=load_text_encoder, + load_transformer=load_transformer, + load_scheduler=load_scheduler, + ) + return pipeline - @abstractmethod def load_checkpoint( - self, step=None - ) -> Tuple[ - Optional[WanPipeline2_1 | WanPipeline2_2 | WanPipelineI2V_2_1 | WanPipelineI2V_2_2], Optional[dict], Optional[int] - ]: + self, + step=None, + vae_only=False, + load_vae=None, + load_text_encoder=None, + load_transformer=None, + load_scheduler=None, + ) -> Tuple[T, Optional[dict], Optional[int]]: + restored_checkpoint, step = self.load_wan_configs_from_orbax(step) + opt_state = None + if restored_checkpoint: + max_logging.log("Loading WAN pipeline from checkpoint") + pipeline = self.pipeline_class.from_checkpoint( + self.config, + restored_checkpoint, + vae_only=vae_only, + load_vae=load_vae, + load_text_encoder=load_text_encoder, + load_transformer=load_transformer, + load_scheduler=load_scheduler, + ) + opt_state = self._extract_opt_state(restored_checkpoint) + else: + max_logging.log("No checkpoint found, loading default pipeline.") + pipeline = self.load_diffusers_checkpoint( + vae_only=vae_only, + load_vae=load_vae, + load_text_encoder=load_text_encoder, + load_transformer=load_transformer, + load_scheduler=load_scheduler, + ) + + return pipeline, opt_state, step + + @abstractmethod + def _extract_opt_state(self, restored_checkpoint): raise NotImplementedError @abstractmethod diff --git a/src/maxdiffusion/checkpointing/wan_checkpointer_2_1.py b/src/maxdiffusion/checkpointing/wan_checkpointer_2_1.py index 9ea7de30d..492058995 100644 --- a/src/maxdiffusion/checkpointing/wan_checkpointer_2_1.py +++ b/src/maxdiffusion/checkpointing/wan_checkpointer_2_1.py @@ -24,7 +24,8 @@ from ..pipelines.wan.wan_pipeline_2_1 import WanPipeline2_1 -class WanCheckpointer2_1(WanCheckpointer): +class WanCheckpointer2_1(WanCheckpointer[WanPipeline2_1]): + pipeline_class = WanPipeline2_1 def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]: if step is None: @@ -58,23 +59,10 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic max_logging.log(f"optimizer state saved in attribute self.opt_state {self.opt_state}") return restored_checkpoint, step - def load_diffusers_checkpoint(self): - pipeline = WanPipeline2_1.from_pretrained(self.config) - return pipeline - - def load_checkpoint(self, step=None) -> Tuple[WanPipeline2_1, Optional[dict], Optional[int]]: - restored_checkpoint, step = self.load_wan_configs_from_orbax(step) - opt_state = None - if restored_checkpoint: - max_logging.log("Loading WAN pipeline from checkpoint") - pipeline = WanPipeline2_1.from_checkpoint(self.config, restored_checkpoint) - if "opt_state" in restored_checkpoint.wan_state.keys(): - opt_state = restored_checkpoint.wan_state["opt_state"] - else: - max_logging.log("No checkpoint found, loading default pipeline.") - pipeline = self.load_diffusers_checkpoint() - - return pipeline, opt_state, step + def _extract_opt_state(self, restored_checkpoint): + if "opt_state" in restored_checkpoint.wan_state.keys(): + return restored_checkpoint.wan_state["opt_state"] + return None def save_checkpoint(self, train_step, pipeline: WanPipeline2_1, train_states: dict): """Saves the training state and model configurations.""" diff --git a/src/maxdiffusion/checkpointing/wan_checkpointer_2_2.py b/src/maxdiffusion/checkpointing/wan_checkpointer_2_2.py index 6b1e0754e..20b984447 100644 --- a/src/maxdiffusion/checkpointing/wan_checkpointer_2_2.py +++ b/src/maxdiffusion/checkpointing/wan_checkpointer_2_2.py @@ -24,7 +24,8 @@ from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer -class WanCheckpointer2_2(WanCheckpointer): +class WanCheckpointer2_2(WanCheckpointer[WanPipeline2_2]): + pipeline_class = WanPipeline2_2 def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]: if step is None: @@ -79,26 +80,12 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic max_logging.log(f"optimizer state saved in attribute self.opt_state {self.opt_state}") return restored_checkpoint, step - def load_diffusers_checkpoint(self): - pipeline = WanPipeline2_2.from_pretrained(self.config) - return pipeline - - def load_checkpoint(self, step=None) -> Tuple[WanPipeline2_2, Optional[dict], Optional[int]]: - restored_checkpoint, step = self.load_wan_configs_from_orbax(step) - opt_state = None - if restored_checkpoint: - max_logging.log("Loading WAN pipeline from checkpoint") - pipeline = WanPipeline2_2.from_checkpoint(self.config, restored_checkpoint) - # Check for optimizer state in either transformer - if "opt_state" in restored_checkpoint.low_noise_transformer_state.keys(): - opt_state = restored_checkpoint.low_noise_transformer_state["opt_state"] - elif "opt_state" in restored_checkpoint.high_noise_transformer_state.keys(): - opt_state = restored_checkpoint.high_noise_transformer_state["opt_state"] - else: - max_logging.log("No checkpoint found, loading default pipeline.") - pipeline = self.load_diffusers_checkpoint() - - return pipeline, opt_state, step + def _extract_opt_state(self, restored_checkpoint): + if "opt_state" in restored_checkpoint.low_noise_transformer_state.keys(): + return restored_checkpoint.low_noise_transformer_state["opt_state"] + elif "opt_state" in restored_checkpoint.high_noise_transformer_state.keys(): + return restored_checkpoint.high_noise_transformer_state["opt_state"] + return None def save_checkpoint(self, train_step, pipeline: WanPipeline2_2, train_states: dict): """Saves the training state and model configurations.""" diff --git a/src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p1.py b/src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p1.py index ccb10af6e..b35933646 100644 --- a/src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p1.py +++ b/src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p1.py @@ -24,7 +24,8 @@ from ..pipelines.wan.wan_pipeline_i2v_2p1 import WanPipelineI2V_2_1 -class WanCheckpointerI2V_2_1(WanCheckpointer): +class WanCheckpointerI2V_2_1(WanCheckpointer[WanPipelineI2V_2_1]): + pipeline_class = WanPipelineI2V_2_1 def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]: if step is None: @@ -58,23 +59,10 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic max_logging.log(f"optimizer state saved in attribute self.opt_state {self.opt_state}") return restored_checkpoint, step - def load_diffusers_checkpoint(self): - pipeline = WanPipelineI2V_2_1.from_pretrained(self.config) - return pipeline - - def load_checkpoint(self, step=None) -> Tuple[WanPipelineI2V_2_1, Optional[dict], Optional[int]]: - restored_checkpoint, step = self.load_wan_configs_from_orbax(step) - opt_state = None - if restored_checkpoint: - max_logging.log("Loading WAN pipeline from checkpoint") - pipeline = WanPipelineI2V_2_1.from_checkpoint(self.config, restored_checkpoint) - if "opt_state" in restored_checkpoint.wan_state.keys(): - opt_state = restored_checkpoint.wan_state["opt_state"] - else: - max_logging.log("No checkpoint found, loading default pipeline.") - pipeline = self.load_diffusers_checkpoint() - - return pipeline, opt_state, step + def _extract_opt_state(self, restored_checkpoint): + if "opt_state" in restored_checkpoint.wan_state.keys(): + return restored_checkpoint.wan_state["opt_state"] + return None def save_checkpoint(self, train_step, pipeline: WanPipelineI2V_2_1, train_states: dict): """Saves the training state and model configurations.""" diff --git a/src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p2.py b/src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p2.py index ce3cc7bb1..943888c83 100644 --- a/src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p2.py +++ b/src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p2.py @@ -24,7 +24,8 @@ from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer -class WanCheckpointerI2V_2_2(WanCheckpointer): +class WanCheckpointerI2V_2_2(WanCheckpointer[WanPipelineI2V_2_2]): + pipeline_class = WanPipelineI2V_2_2 def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]: if step is None: @@ -79,26 +80,12 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic max_logging.log(f"optimizer state saved in attribute self.opt_state {self.opt_state}") return restored_checkpoint, step - def load_diffusers_checkpoint(self): - pipeline = WanPipelineI2V_2_2.from_pretrained(self.config) - return pipeline - - def load_checkpoint(self, step=None) -> Tuple[WanPipelineI2V_2_2, Optional[dict], Optional[int]]: - restored_checkpoint, step = self.load_wan_configs_from_orbax(step) - opt_state = None - if restored_checkpoint: - max_logging.log("Loading WAN pipeline from checkpoint") - pipeline = WanPipelineI2V_2_2.from_checkpoint(self.config, restored_checkpoint) - # Check for optimizer state in either transformer - if "opt_state" in restored_checkpoint.low_noise_transformer_state.keys(): - opt_state = restored_checkpoint.low_noise_transformer_state["opt_state"] - elif "opt_state" in restored_checkpoint.high_noise_transformer_state.keys(): - opt_state = restored_checkpoint.high_noise_transformer_state["opt_state"] - else: - max_logging.log("No checkpoint found, loading default pipeline.") - pipeline = self.load_diffusers_checkpoint() - - return pipeline, opt_state, step + def _extract_opt_state(self, restored_checkpoint): + if "opt_state" in restored_checkpoint.low_noise_transformer_state.keys(): + return restored_checkpoint.low_noise_transformer_state["opt_state"] + elif "opt_state" in restored_checkpoint.high_noise_transformer_state.keys(): + return restored_checkpoint.high_noise_transformer_state["opt_state"] + return None def save_checkpoint(self, train_step, pipeline: WanPipelineI2V_2_2, train_states: dict): """Saves the training state and model configurations.""" diff --git a/src/maxdiffusion/checkpointing/wan_vace_checkpointer_2_1.py b/src/maxdiffusion/checkpointing/wan_vace_checkpointer_2_1.py index c5e9d2159..a5a354180 100644 --- a/src/maxdiffusion/checkpointing/wan_vace_checkpointer_2_1.py +++ b/src/maxdiffusion/checkpointing/wan_vace_checkpointer_2_1.py @@ -23,7 +23,8 @@ from ..pipelines.wan.wan_vace_pipeline_2_1 import VaceWanPipeline2_1 -class WanVaceCheckpointer2_1(WanCheckpointer): +class WanVaceCheckpointer2_1(WanCheckpointer[VaceWanPipeline2_1]): + pipeline_class = VaceWanPipeline2_1 def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]: if step is None: @@ -57,23 +58,10 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic max_logging.log(f"optimizer state saved in attribute self.opt_state {self.opt_state}") return restored_checkpoint, step - def load_diffusers_checkpoint(self): - pipeline = VaceWanPipeline2_1.from_pretrained(self.config) - return pipeline - - def load_checkpoint(self, step=None) -> Tuple[VaceWanPipeline2_1, Optional[dict], Optional[int]]: - restored_checkpoint, step = self.load_wan_configs_from_orbax(step) - opt_state = None - if restored_checkpoint: - max_logging.log("Loading WAN pipeline from checkpoint") - pipeline = VaceWanPipeline2_1.from_checkpoint(self.config, restored_checkpoint) - if "opt_state" in restored_checkpoint.wan_state.keys(): - opt_state = restored_checkpoint.wan_state["opt_state"] - else: - max_logging.log("No checkpoint found, loading default pipeline.") - pipeline = self.load_diffusers_checkpoint() - - return pipeline, opt_state, step + def _extract_opt_state(self, restored_checkpoint): + if "opt_state" in restored_checkpoint.wan_state.keys(): + return restored_checkpoint.wan_state["opt_state"] + return None def save_checkpoint(self, train_step, pipeline: VaceWanPipeline2_1, train_states: dict): """Saves the training state and model configurations.""" diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 75b57ce91..49a85c852 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -267,6 +267,8 @@ class WanPipeline: Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. """ + _transformer_keys = ["transformer"] + def __init__( self, tokenizer: AutoTokenizer, @@ -464,6 +466,8 @@ def quantize_transformer( mesh: Mesh, ): """Quantizes the transformer model.""" + if model is None: + return None q_rules = cls.get_qt_provider(config) if not q_rules: return model @@ -726,7 +730,14 @@ def _decode_latents_to_video(self, latents: jax.Array, trace: Optional[dict] = N return video @classmethod - def _create_common_components(cls, config, vae_only=False, i2v=False): + def _create_common_components( + cls, + config, + load_vae=True, + load_text_encoder=True, + load_scheduler=True, + i2v=False, + ): devices_array = max_utils.create_device_mesh(config) mesh = Mesh(devices_array, config.mesh_axes) @@ -753,18 +764,9 @@ def _create_common_components(cls, config, vae_only=False, i2v=False): rng = jax.random.key(config.seed) rngs = nnx.Rngs(rng) - with vae_mesh: - wan_vae, vae_cache = cls.load_vae( - devices_array=devices_array, - mesh=vae_mesh, - rngs=rngs, - config=config, - vae_logical_axis_rules=vae_logical_axis_rules, - ) - components = { - "vae": wan_vae, - "vae_cache": vae_cache, + "vae": None, + "vae_cache": None, "devices_array": devices_array, "rngs": rngs, "mesh": mesh, @@ -778,17 +780,144 @@ def _create_common_components(cls, config, vae_only=False, i2v=False): "image_encoder": None, } - if not vae_only: + if load_vae: + max_logging.log("Loading VAE") + components["vae"], components["vae_cache"] = cls.load_vae( + devices_array=devices_array, + mesh=vae_mesh, + rngs=rngs, + config=config, + vae_logical_axis_rules=vae_logical_axis_rules, + ) + + if load_text_encoder: + max_logging.log("Loading Tokenizer and Text Encoder") components["tokenizer"] = cls.load_tokenizer(config=config) components["text_encoder"] = cls.load_text_encoder(config=config) - components["scheduler"], components["scheduler_state"] = cls.load_scheduler(config=config) if cls._needs_image_encoder(config, i2v=i2v): ( components["image_processor"], components["image_encoder"], ) = cls.load_image_encoder(config) + + if load_scheduler: + components["scheduler"], components["scheduler_state"] = cls.load_scheduler(config=config) + return components + @classmethod + @abstractmethod + def _load_and_init( + cls, + config, + restored_checkpoint=None, + load_vae=True, + load_text_encoder=True, + load_transformer=True, + load_scheduler=True, + ): + """Loads and initializes the pipeline components.""" + raise NotImplementedError + + @classmethod + def _resolve_and_validate_load_flags( + cls, + vae_only=False, + load_vae=None, + load_text_encoder=None, + load_transformer=None, + load_scheduler=None, + ) -> Tuple[bool, bool, bool, bool]: + if vae_only: + if load_vae is False: + raise ValueError("Conflict: vae_only=True but load_vae=False") + if load_text_encoder is True: + raise ValueError("Conflict: vae_only=True but load_text_encoder=True") + if load_transformer is True: + raise ValueError("Conflict: vae_only=True but load_transformer=True") + if load_scheduler is True: + raise ValueError("Conflict: vae_only=True but load_scheduler=True") + return True, False, False, False + + return ( + True if load_vae is None else load_vae, + True if load_text_encoder is None else load_text_encoder, + True if load_transformer is None else load_transformer, + True if load_scheduler is None else load_scheduler, + ) + + @classmethod + def from_pretrained( + cls, + config, + vae_only=False, + load_vae=None, + load_text_encoder=None, + load_transformer=None, + load_scheduler=None, + ): + ( + load_vae, + load_text_encoder, + load_transformer, + load_scheduler, + ) = cls._resolve_and_validate_load_flags( + vae_only=vae_only, + load_vae=load_vae, + load_text_encoder=load_text_encoder, + load_transformer=load_transformer, + load_scheduler=load_scheduler, + ) + outputs = cls._load_and_init( + config, + None, + load_vae=load_vae, + load_text_encoder=load_text_encoder, + load_transformer=load_transformer, + load_scheduler=load_scheduler, + ) + pipeline = outputs[0] + loaded_transformers = outputs[1:] + + for key, transformer in zip(cls._transformer_keys, loaded_transformers): + quantized = cls.quantize_transformer(config, transformer, pipeline, pipeline.mesh) + setattr(pipeline, key, quantized) + + return pipeline + + @classmethod + def from_checkpoint( + cls, + config, + restored_checkpoint=None, + vae_only=False, + load_vae=None, + load_text_encoder=None, + load_transformer=None, + load_scheduler=None, + ): + ( + load_vae, + load_text_encoder, + load_transformer, + load_scheduler, + ) = cls._resolve_and_validate_load_flags( + vae_only=vae_only, + load_vae=load_vae, + load_text_encoder=load_text_encoder, + load_transformer=load_transformer, + load_scheduler=load_scheduler, + ) + outputs = cls._load_and_init( + config, + restored_checkpoint=restored_checkpoint, + load_vae=load_vae, + load_text_encoder=load_text_encoder, + load_transformer=load_transformer, + load_scheduler=load_scheduler, + ) + return outputs[0] + @classmethod def _needs_image_encoder(cls, config: HyperParameters, i2v: bool = False) -> bool: return i2v and config.model_name == "wan2.1" @@ -878,68 +1007,66 @@ def _prepare_model_inputs( latents: jax.Array = None, prompt_embeds: jax.Array = None, negative_prompt_embeds: jax.Array = None, - vae_only: bool = False, ): if max_sequence_length is None: max_sequence_length = getattr(self.config, "max_sequence_length", 512) - if not vae_only: - if num_frames % self.vae_scale_factor_temporal != 1: - max_logging.log( - f"`num_frames -1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number." - ) - num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 - num_frames = max(num_frames, 1) + if num_frames % self.vae_scale_factor_temporal != 1: + max_logging.log( + f"`num_frames -1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number." + ) + num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + num_frames = max(num_frames, 1) - # 2. Define call parameters - if prompt is not None and isinstance(prompt, str): - prompt = [prompt] + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + prompt = [prompt] - batch_size = len(prompt) if prompt is not None else prompt_embeds.shape[0] // num_videos_per_prompt + batch_size = len(prompt) if prompt is not None else prompt_embeds.shape[0] // num_videos_per_prompt - with jax.named_scope("Encode-Prompt"): - prompt_embeds, negative_prompt_embeds = self.encode_prompt( - prompt=prompt, - negative_prompt=negative_prompt, - max_sequence_length=max_sequence_length, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - ) + with jax.named_scope("Encode-Prompt"): + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + max_sequence_length=max_sequence_length, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) - num_channel_latents = self._get_num_channel_latents() - if latents is None: - latents = self.prepare_latents( - batch_size=batch_size, - vae_scale_factor_temporal=self.vae_scale_factor_temporal, - vae_scale_factor_spatial=self.vae_scale_factor_spatial, - height=height, - width=width, - num_frames=num_frames, - num_channels_latents=num_channel_latents, - ) + num_channel_latents = self._get_num_channel_latents() + if latents is None: + latents = self.prepare_latents( + batch_size=batch_size, + vae_scale_factor_temporal=self.vae_scale_factor_temporal, + vae_scale_factor_spatial=self.vae_scale_factor_spatial, + height=height, + width=width, + num_frames=num_frames, + num_channels_latents=num_channel_latents, + ) - data_sharding = NamedSharding(self.mesh, P()) - # Using global_batch_size_to_train_on so not to create more config variables - if self.config.global_batch_size_to_train_on // self.config.per_device_batch_size == 0: - data_sharding = jax.sharding.NamedSharding(self.mesh, P(*self.config.data_sharding)) + data_sharding = NamedSharding(self.mesh, P()) + # Using global_batch_size_to_train_on so not to create more config variables + if self.config.global_batch_size_to_train_on // self.config.per_device_batch_size == 0: + data_sharding = jax.sharding.NamedSharding(self.mesh, P(*self.config.data_sharding)) - latents = jax.device_put(latents, data_sharding) - prompt_embeds = jax.device_put(prompt_embeds, data_sharding) - negative_prompt_embeds = jax.device_put(negative_prompt_embeds, data_sharding) + latents = jax.device_put(latents, data_sharding) + prompt_embeds = jax.device_put(prompt_embeds, data_sharding) + negative_prompt_embeds = jax.device_put(negative_prompt_embeds, data_sharding) - scheduler_state = self.scheduler.set_timesteps( - self.scheduler_state, - num_inference_steps=num_inference_steps, - shape=latents.shape, - ) + scheduler_state = self.scheduler.set_timesteps( + self.scheduler_state, + num_inference_steps=num_inference_steps, + shape=latents.shape, + ) - return ( - latents, - prompt_embeds, - negative_prompt_embeds, - scheduler_state, - num_frames, - ) + return ( + latents, + prompt_embeds, + negative_prompt_embeds, + scheduler_state, + num_frames, + ) @abstractmethod def __call__(self, **kwargs): diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py index ff881469c..85daec33a 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py @@ -35,53 +35,48 @@ def __init__(self, config: HyperParameters, transformer: Optional[WanModel], **k self.transformer = transformer @classmethod - def _load_and_init(cls, config, restored_checkpoint=None, vae_only=False, load_transformer=True): - common_components = cls._create_common_components(config, vae_only) - transformer = None - if not vae_only: - if load_transformer: - transformer = super().load_transformer( - devices_array=common_components["devices_array"], - mesh=common_components["mesh"], - rngs=common_components["rngs"], - config=config, - restored_checkpoint=restored_checkpoint, - subfolder="transformer", - ) - - pipeline = cls( - tokenizer=common_components["tokenizer"], - text_encoder=common_components["text_encoder"], - transformer=transformer, - vae=common_components["vae"], - vae_cache=common_components["vae_cache"], - scheduler=common_components["scheduler"], - scheduler_state=common_components["scheduler_state"], - devices_array=common_components["devices_array"], - mesh=common_components["mesh"], - vae_mesh=common_components["vae_mesh"], - vae_logical_axis_rules=common_components["vae_logical_axis_rules"], - config=config, - ) - - return pipeline, transformer - - @classmethod - def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transformer=True): - pipeline, transformer = cls._load_and_init(config, None, vae_only, load_transformer) - pipeline.transformer = cls.quantize_transformer(config, transformer, pipeline, pipeline.mesh) - return pipeline - - @classmethod - def from_checkpoint( + def _load_and_init( cls, config: HyperParameters, restored_checkpoint=None, - vae_only=False, + load_vae=True, + load_text_encoder=True, load_transformer=True, + load_scheduler=True, ): - pipeline, _ = cls._load_and_init(config, restored_checkpoint, vae_only, load_transformer) - return pipeline + common_components = cls._create_common_components( + config, + load_vae=load_vae, + load_text_encoder=load_text_encoder, + load_scheduler=load_scheduler, + ) + transformer = None + if load_transformer: + transformer = super().load_transformer( + devices_array=common_components["devices_array"], + mesh=common_components["mesh"], + rngs=common_components["rngs"], + config=config, + restored_checkpoint=restored_checkpoint, + subfolder="transformer", + ) + + pipeline = cls( + tokenizer=common_components["tokenizer"], + text_encoder=common_components["text_encoder"], + transformer=transformer, + vae=common_components["vae"], + vae_cache=common_components["vae_cache"], + scheduler=common_components["scheduler"], + scheduler_state=common_components["scheduler_state"], + devices_array=common_components["devices_array"], + mesh=common_components["mesh"], + vae_mesh=common_components["vae_mesh"], + vae_logical_axis_rules=common_components["vae_logical_axis_rules"], + config=config, + ) + + return pipeline, transformer def _get_num_channel_latents(self) -> int: return self.transformer.config.in_channels @@ -100,7 +95,6 @@ def __call__( latents: Optional[jax.Array] = None, prompt_embeds: Optional[jax.Array] = None, negative_prompt_embeds: Optional[jax.Array] = None, - vae_only: bool = False, use_cfg_cache: bool = False, use_magcache: bool = False, magcache_thresh: Optional[float] = None, @@ -138,7 +132,6 @@ def __call__( latents, prompt_embeds, negative_prompt_embeds, - vae_only, ) latents.block_until_ready() prompt_embeds.block_until_ready() diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py index 43f588c46..00d11f961 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py @@ -31,6 +31,8 @@ class WanPipeline2_2(WanPipeline): """Pipeline for WAN 2.2 with dual transformers.""" + _transformer_keys = ["low_noise_transformer", "high_noise_transformer"] + def __init__( self, config: HyperParameters, @@ -44,10 +46,23 @@ def __init__( self.boundary_ratio = config.boundary_ratio @classmethod - def _load_and_init(cls, config, restored_checkpoint=None, vae_only=False, load_transformer=True): - common_components = cls._create_common_components(config, vae_only) + def _load_and_init( + cls, + config, + restored_checkpoint=None, + load_vae=True, + load_text_encoder=True, + load_transformer=True, + load_scheduler=True, + ): + common_components = cls._create_common_components( + config, + load_vae=load_vae, + load_text_encoder=load_text_encoder, + load_scheduler=load_scheduler, + ) low_noise_transformer, high_noise_transformer = None, None - if not vae_only and load_transformer: + if load_transformer: low_noise_transformer = super().load_transformer( devices_array=common_components["devices_array"], mesh=common_components["mesh"], @@ -65,42 +80,22 @@ def _load_and_init(cls, config, restored_checkpoint=None, vae_only=False, load_t subfolder="transformer", ) - pipeline = cls( - tokenizer=common_components["tokenizer"], - text_encoder=common_components["text_encoder"], - low_noise_transformer=low_noise_transformer, - high_noise_transformer=high_noise_transformer, - vae=common_components["vae"], - vae_cache=common_components["vae_cache"], - scheduler=common_components["scheduler"], - scheduler_state=common_components["scheduler_state"], - devices_array=common_components["devices_array"], - mesh=common_components["mesh"], - vae_mesh=common_components["vae_mesh"], - vae_logical_axis_rules=common_components["vae_logical_axis_rules"], - config=config, - ) - return pipeline, low_noise_transformer, high_noise_transformer - - @classmethod - def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transformer=True): - pipeline, low_noise_transformer, high_noise_transformer = cls._load_and_init(config, None, vae_only, load_transformer) - pipeline.low_noise_transformer = cls.quantize_transformer(config, low_noise_transformer, pipeline, pipeline.mesh) - pipeline.high_noise_transformer = cls.quantize_transformer(config, high_noise_transformer, pipeline, pipeline.mesh) - return pipeline - - @classmethod - def from_checkpoint( - cls, - config: HyperParameters, - restored_checkpoint=None, - vae_only=False, - load_transformer=True, - ): - pipeline, low_noise_transformer, high_noise_transformer = cls._load_and_init( - config, restored_checkpoint, vae_only, load_transformer + pipeline = cls( + tokenizer=common_components["tokenizer"], + text_encoder=common_components["text_encoder"], + low_noise_transformer=low_noise_transformer, + high_noise_transformer=high_noise_transformer, + vae=common_components["vae"], + vae_cache=common_components["vae_cache"], + scheduler=common_components["scheduler"], + scheduler_state=common_components["scheduler_state"], + devices_array=common_components["devices_array"], + mesh=common_components["mesh"], + vae_mesh=common_components["vae_mesh"], + vae_logical_axis_rules=common_components["vae_logical_axis_rules"], + config=config, ) - return pipeline + return pipeline, low_noise_transformer, high_noise_transformer def _get_num_channel_latents(self) -> int: return self.low_noise_transformer.config.in_channels @@ -120,7 +115,6 @@ def __call__( latents: jax.Array = None, prompt_embeds: jax.Array = None, negative_prompt_embeds: jax.Array = None, - vae_only: bool = False, use_cfg_cache: bool = False, use_sen_cache: bool = False, use_kv_cache: bool = False, @@ -167,7 +161,6 @@ def __call__( latents, prompt_embeds, negative_prompt_embeds, - vae_only, ) latents.block_until_ready() prompt_embeds.block_until_ready() diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_animate.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_animate.py index 69fc70b02..ffd87ddfd 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_animate.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_animate.py @@ -277,12 +277,19 @@ def _load_and_init( cls, config: HyperParameters, restored_checkpoint=None, - vae_only: bool = False, + load_vae: bool = True, + load_text_encoder: bool = True, load_transformer: bool = True, + load_scheduler: bool = True, ) -> Tuple["WanAnimatePipeline", Optional[WanAnimateTransformer3DModel]]: - common_components = cls._create_common_components(config, vae_only) + common_components = cls._create_common_components( + config, + load_vae=load_vae, + load_text_encoder=load_text_encoder, + load_scheduler=load_scheduler, + ) transformer = None - if not vae_only and load_transformer: + if load_transformer: transformer = cls.load_animate_transformer( devices_array=common_components["devices_array"], mesh=common_components["mesh"], @@ -309,28 +316,6 @@ def _load_and_init( ) return pipeline, transformer - @classmethod - def from_pretrained( - cls, - config: HyperParameters, - vae_only: bool = False, - load_transformer: bool = True, - ) -> "WanAnimatePipeline": - pipeline, transformer = cls._load_and_init(config, None, vae_only, load_transformer) - pipeline.transformer = cls.quantize_transformer(config, transformer, pipeline, pipeline.mesh) - return pipeline - - @classmethod - def from_checkpoint( - cls, - config: HyperParameters, - restored_checkpoint=None, - vae_only: bool = False, - load_transformer: bool = True, - ) -> "WanAnimatePipeline": - pipeline, _ = cls._load_and_init(config, restored_checkpoint, vae_only, load_transformer) - return pipeline - # ------------------------------------------------------------------ # Abstract method implementation # ------------------------------------------------------------------ diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py index 4902bcc3e..3bfcc7513 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py @@ -38,19 +38,32 @@ def __init__(self, config: HyperParameters, transformer: Optional[WanModel], **k self.transformer = transformer @classmethod - def _load_and_init(cls, config, restored_checkpoint=None, vae_only=False, load_transformer=True): - common_components = cls._create_common_components(config, vae_only, i2v=True) + def _load_and_init( + cls, + config, + restored_checkpoint=None, + load_vae=True, + load_text_encoder=True, + load_transformer=True, + load_scheduler=True, + ): + common_components = cls._create_common_components( + config, + load_vae=load_vae, + load_text_encoder=load_text_encoder, + load_scheduler=load_scheduler, + i2v=True, + ) transformer = None - if not vae_only: - if load_transformer: - transformer = super().load_transformer( - devices_array=common_components["devices_array"], - mesh=common_components["mesh"], - rngs=common_components["rngs"], - config=config, - restored_checkpoint=restored_checkpoint, - subfolder="transformer", - ) + if load_transformer: + transformer = super().load_transformer( + devices_array=common_components["devices_array"], + mesh=common_components["mesh"], + rngs=common_components["rngs"], + config=config, + restored_checkpoint=restored_checkpoint, + subfolder="transformer", + ) pipeline = cls( tokenizer=common_components["tokenizer"], @@ -70,23 +83,6 @@ def _load_and_init(cls, config, restored_checkpoint=None, vae_only=False, load_t ) return pipeline, transformer - @classmethod - def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transformer=True): - pipeline, transformer = cls._load_and_init(config, None, vae_only, load_transformer) - pipeline.transformer = cls.quantize_transformer(config, transformer, pipeline, pipeline.mesh) - return pipeline - - @classmethod - def from_checkpoint( - cls, - config: HyperParameters, - restored_checkpoint=None, - vae_only=False, - load_transformer=True, - ): - pipeline, _ = cls._load_and_init(config, restored_checkpoint, vae_only, load_transformer) - return pipeline - def prepare_latents( self, image: jax.Array, diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py index 568398698..f071c231f 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py @@ -33,6 +33,8 @@ class WanPipelineI2V_2_2(WanPipeline): """Pipeline for WAN 2.2 Image-to-Video.""" + _transformer_keys = ["low_noise_transformer", "high_noise_transformer"] + def __init__( self, config: HyperParameters, @@ -46,27 +48,40 @@ def __init__( self.boundary_ratio = config.boundary_ratio @classmethod - def _load_and_init(cls, config, restored_checkpoint=None, vae_only=False, load_transformer=True): - common_components = cls._create_common_components(config, vae_only, i2v=True) + def _load_and_init( + cls, + config, + restored_checkpoint=None, + load_vae=True, + load_text_encoder=True, + load_transformer=True, + load_scheduler=True, + ): + common_components = cls._create_common_components( + config, + load_vae=load_vae, + load_text_encoder=load_text_encoder, + load_scheduler=load_scheduler, + i2v=True, + ) low_noise_transformer, high_noise_transformer = None, None - if not vae_only: - if load_transformer: - high_noise_transformer = super().load_transformer( - devices_array=common_components["devices_array"], - mesh=common_components["mesh"], - rngs=common_components["rngs"], - config=config, - restored_checkpoint=restored_checkpoint, - subfolder="transformer", - ) - low_noise_transformer = super().load_transformer( - devices_array=common_components["devices_array"], - mesh=common_components["mesh"], - rngs=common_components["rngs"], - config=config, - restored_checkpoint=restored_checkpoint, - subfolder="transformer_2", - ) + if load_transformer: + high_noise_transformer = super().load_transformer( + devices_array=common_components["devices_array"], + mesh=common_components["mesh"], + rngs=common_components["rngs"], + config=config, + restored_checkpoint=restored_checkpoint, + subfolder="transformer", + ) + low_noise_transformer = super().load_transformer( + devices_array=common_components["devices_array"], + mesh=common_components["mesh"], + rngs=common_components["rngs"], + config=config, + restored_checkpoint=restored_checkpoint, + subfolder="transformer_2", + ) pipeline = cls( tokenizer=common_components["tokenizer"], @@ -87,24 +102,6 @@ def _load_and_init(cls, config, restored_checkpoint=None, vae_only=False, load_t ) return pipeline, low_noise_transformer, high_noise_transformer - @classmethod - def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transformer=True): - pipeline, low_noise_transformer, high_noise_transformer = cls._load_and_init(config, None, vae_only, load_transformer) - pipeline.low_noise_transformer = cls.quantize_transformer(config, low_noise_transformer, pipeline, pipeline.mesh) - pipeline.high_noise_transformer = cls.quantize_transformer(config, high_noise_transformer, pipeline, pipeline.mesh) - return pipeline - - @classmethod - def from_checkpoint( - cls, - config: HyperParameters, - restored_checkpoint=None, - vae_only=False, - load_transformer=True, - ): - pipeline, _, _ = cls._load_and_init(config, restored_checkpoint, vae_only, load_transformer) - return pipeline - def prepare_latents( self, image: jax.Array, diff --git a/src/maxdiffusion/pipelines/wan/wan_vace_pipeline_2_1.py b/src/maxdiffusion/pipelines/wan/wan_vace_pipeline_2_1.py index f561da56f..dcdf9396d 100644 --- a/src/maxdiffusion/pipelines/wan/wan_vace_pipeline_2_1.py +++ b/src/maxdiffusion/pipelines/wan/wan_vace_pipeline_2_1.py @@ -351,22 +351,28 @@ def _load_and_init( cls, config: HyperParameters, restored_checkpoint=None, - vae_only=False, + load_vae=True, + load_text_encoder=True, load_transformer=True, + load_scheduler=True, ): - common_components = cls._create_common_components(config, vae_only) + common_components = cls._create_common_components( + config, + load_vae=load_vae, + load_text_encoder=load_text_encoder, + load_scheduler=load_scheduler, + ) transformer = None - if not vae_only: - if load_transformer: - transformer = cls.load_transformer( - devices_array=common_components["devices_array"], - mesh=common_components["mesh"], - rngs=common_components["rngs"], - config=config, - restored_checkpoint=restored_checkpoint, - subfolder="transformer", - ) + if load_transformer: + transformer = cls.load_transformer( + devices_array=common_components["devices_array"], + mesh=common_components["mesh"], + rngs=common_components["rngs"], + config=config, + restored_checkpoint=restored_checkpoint, + subfolder="transformer", + ) pipeline = cls( tokenizer=common_components["tokenizer"], @@ -383,35 +389,7 @@ def _load_and_init( config=config, ) - return pipeline - - @classmethod - def from_pretrained( - cls, - config: HyperParameters, - vae_only=False, - load_transformer=True, - ): - pipeline = cls._load_and_init(config, None, vae_only, load_transformer) - pipeline.transformer = cls.quantize_transformer(config, pipeline.transformer, pipeline, pipeline.mesh) - return pipeline - - @classmethod - def from_checkpoint( - cls, - config: HyperParameters, - restored_checkpoint=None, - vae_only=False, - load_transformer=True, - ): - pipeline = cls._load_and_init( - config, - restored_checkpoint, - vae_only, - load_transformer, - ) - pipeline.transformer = cls.quantize_transformer(config, pipeline.transformer, pipeline, pipeline.mesh) - return pipeline + return pipeline, transformer def check_inputs( self, diff --git a/src/maxdiffusion/tests/wan/wan_checkpointer_test.py b/src/maxdiffusion/tests/wan/wan_checkpointer_test.py index b18b0df7e..b30006edf 100644 --- a/src/maxdiffusion/tests/wan/wan_checkpointer_test.py +++ b/src/maxdiffusion/tests/wan/wan_checkpointer_test.py @@ -19,7 +19,10 @@ from maxdiffusion.checkpointing.wan_checkpointer_i2v_2p1 import WanCheckpointerI2V_2_1 from maxdiffusion.checkpointing.wan_checkpointer_i2v_2p2 import WanCheckpointerI2V_2_2 from maxdiffusion.pipelines.wan.wan_pipeline import _select_restored_transformer_state +from maxdiffusion.pipelines.wan.wan_pipeline_2_1 import WanPipeline2_1 +from maxdiffusion.pipelines.wan.wan_pipeline_2_2 import WanPipeline2_2 from maxdiffusion.pipelines.wan.wan_pipeline_i2v_2p1 import WanPipelineI2V_2_1 +from maxdiffusion.pipelines.wan.wan_pipeline_i2v_2p2 import WanPipelineI2V_2_2 class WanPretrainedCacheTest(unittest.TestCase): @@ -166,27 +169,34 @@ def setUp(self): self.config.dataset_type = "test_dataset" @patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager") - @patch("maxdiffusion.checkpointing.wan_checkpointer_2_1.WanPipeline2_1") - def test_load_from_diffusers(self, mock_wan_pipeline, mock_create_manager): + @patch.object(WanPipeline2_1, "from_pretrained", autospec=True) + def test_load_from_diffusers(self, mock_from_pretrained, mock_create_manager): mock_manager = MagicMock() mock_manager.latest_step.return_value = None mock_create_manager.return_value = mock_manager mock_pipeline_instance = MagicMock() - mock_wan_pipeline.from_pretrained.return_value = mock_pipeline_instance + mock_from_pretrained.return_value = mock_pipeline_instance checkpointer = WanCheckpointer2_1(config=self.config) pipeline, opt_state, step = checkpointer.load_checkpoint(step=None) mock_manager.latest_step.assert_called_once() - mock_wan_pipeline.from_pretrained.assert_called_once_with(self.config) + mock_from_pretrained.assert_called_once_with( + self.config, + vae_only=False, + load_vae=None, + load_text_encoder=None, + load_transformer=None, + load_scheduler=None, + ) self.assertEqual(pipeline, mock_pipeline_instance) self.assertIsNone(opt_state) self.assertIsNone(step) @patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager") - @patch("maxdiffusion.checkpointing.wan_checkpointer_2_1.WanPipeline2_1") - def test_load_checkpoint_no_optimizer(self, mock_wan_pipeline, mock_create_manager): + @patch.object(WanPipeline2_1, "from_checkpoint", autospec=True) + def test_load_checkpoint_no_optimizer(self, mock_from_checkpoint, mock_create_manager): mock_manager = MagicMock() mock_manager.latest_step.return_value = 1 metadata_mock = MagicMock() @@ -203,20 +213,28 @@ def test_load_checkpoint_no_optimizer(self, mock_wan_pipeline, mock_create_manag mock_create_manager.return_value = mock_manager mock_pipeline_instance = MagicMock() - mock_wan_pipeline.from_checkpoint.return_value = mock_pipeline_instance + mock_from_checkpoint.return_value = mock_pipeline_instance checkpointer = WanCheckpointer2_1(config=self.config) pipeline, opt_state, step = checkpointer.load_checkpoint(step=1) mock_manager.restore.assert_called_once_with(step=1, args=unittest.mock.ANY) - mock_wan_pipeline.from_checkpoint.assert_called_with(self.config, mock_manager.restore.return_value) + mock_from_checkpoint.assert_called_with( + self.config, + mock_manager.restore.return_value, + vae_only=False, + load_vae=None, + load_text_encoder=None, + load_transformer=None, + load_scheduler=None, + ) self.assertEqual(pipeline, mock_pipeline_instance) self.assertIsNone(opt_state) self.assertEqual(step, 1) @patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager") - @patch("maxdiffusion.checkpointing.wan_checkpointer_2_1.WanPipeline2_1") - def test_load_checkpoint_with_optimizer(self, mock_wan_pipeline, mock_create_manager): + @patch.object(WanPipeline2_1, "from_checkpoint", autospec=True) + def test_load_checkpoint_with_optimizer(self, mock_from_checkpoint, mock_create_manager): mock_manager = MagicMock() mock_manager.latest_step.return_value = 1 metadata_mock = MagicMock() @@ -233,13 +251,21 @@ def test_load_checkpoint_with_optimizer(self, mock_wan_pipeline, mock_create_man mock_create_manager.return_value = mock_manager mock_pipeline_instance = MagicMock() - mock_wan_pipeline.from_checkpoint.return_value = mock_pipeline_instance + mock_from_checkpoint.return_value = mock_pipeline_instance checkpointer = WanCheckpointer2_1(config=self.config) pipeline, opt_state, step = checkpointer.load_checkpoint(step=1) mock_manager.restore.assert_called_once_with(step=1, args=unittest.mock.ANY) - mock_wan_pipeline.from_checkpoint.assert_called_with(self.config, mock_manager.restore.return_value) + mock_from_checkpoint.assert_called_with( + self.config, + mock_manager.restore.return_value, + vae_only=False, + load_vae=None, + load_text_encoder=None, + load_transformer=None, + load_scheduler=None, + ) self.assertEqual(pipeline, mock_pipeline_instance) self.assertIsNotNone(opt_state) self.assertEqual(opt_state["learning_rate"], 0.001) @@ -255,28 +281,35 @@ def setUp(self): self.config.dataset_type = "test_dataset" @patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager") - @patch("maxdiffusion.checkpointing.wan_checkpointer_2_2.WanPipeline2_2") - def test_load_from_diffusers(self, mock_wan_pipeline, mock_create_manager): + @patch.object(WanPipeline2_2, "from_pretrained", autospec=True) + def test_load_from_diffusers(self, mock_from_pretrained, mock_create_manager): """Test loading from pretrained when no checkpoint exists.""" mock_manager = MagicMock() mock_manager.latest_step.return_value = None mock_create_manager.return_value = mock_manager mock_pipeline_instance = MagicMock() - mock_wan_pipeline.from_pretrained.return_value = mock_pipeline_instance + mock_from_pretrained.return_value = mock_pipeline_instance checkpointer = WanCheckpointer2_2(config=self.config) pipeline, opt_state, step = checkpointer.load_checkpoint(step=None) mock_manager.latest_step.assert_called_once() - mock_wan_pipeline.from_pretrained.assert_called_once_with(self.config) + mock_from_pretrained.assert_called_once_with( + self.config, + vae_only=False, + load_vae=None, + load_text_encoder=None, + load_transformer=None, + load_scheduler=None, + ) self.assertEqual(pipeline, mock_pipeline_instance) self.assertIsNone(opt_state) self.assertIsNone(step) @patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager") - @patch("maxdiffusion.checkpointing.wan_checkpointer_2_2.WanPipeline2_2") - def test_load_checkpoint_no_optimizer(self, mock_wan_pipeline, mock_create_manager): + @patch.object(WanPipeline2_2, "from_checkpoint", autospec=True) + def test_load_checkpoint_no_optimizer(self, mock_from_checkpoint, mock_create_manager): """Test loading checkpoint without optimizer state.""" mock_manager = MagicMock() mock_manager.latest_step.return_value = 1 @@ -296,20 +329,28 @@ def test_load_checkpoint_no_optimizer(self, mock_wan_pipeline, mock_create_manag mock_create_manager.return_value = mock_manager mock_pipeline_instance = MagicMock() - mock_wan_pipeline.from_checkpoint.return_value = mock_pipeline_instance + mock_from_checkpoint.return_value = mock_pipeline_instance checkpointer = WanCheckpointer2_2(config=self.config) pipeline, opt_state, step = checkpointer.load_checkpoint(step=1) mock_manager.restore.assert_called_once_with(step=1, args=unittest.mock.ANY) - mock_wan_pipeline.from_checkpoint.assert_called_with(self.config, mock_manager.restore.return_value) + mock_from_checkpoint.assert_called_with( + self.config, + mock_manager.restore.return_value, + vae_only=False, + load_vae=None, + load_text_encoder=None, + load_transformer=None, + load_scheduler=None, + ) self.assertEqual(pipeline, mock_pipeline_instance) self.assertIsNone(opt_state) self.assertEqual(step, 1) @patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager") - @patch("maxdiffusion.checkpointing.wan_checkpointer_2_2.WanPipeline2_2") - def test_load_checkpoint_with_optimizer_in_low_noise(self, mock_wan_pipeline, mock_create_manager): + @patch.object(WanPipeline2_2, "from_checkpoint", autospec=True) + def test_load_checkpoint_with_optimizer_in_low_noise(self, mock_from_checkpoint, mock_create_manager): """Test loading checkpoint with optimizer state in low_noise_transformer.""" mock_manager = MagicMock() mock_manager.latest_step.return_value = 1 @@ -329,21 +370,29 @@ def test_load_checkpoint_with_optimizer_in_low_noise(self, mock_wan_pipeline, mo mock_create_manager.return_value = mock_manager mock_pipeline_instance = MagicMock() - mock_wan_pipeline.from_checkpoint.return_value = mock_pipeline_instance + mock_from_checkpoint.return_value = mock_pipeline_instance checkpointer = WanCheckpointer2_2(config=self.config) pipeline, opt_state, step = checkpointer.load_checkpoint(step=1) mock_manager.restore.assert_called_once_with(step=1, args=unittest.mock.ANY) - mock_wan_pipeline.from_checkpoint.assert_called_with(self.config, mock_manager.restore.return_value) + mock_from_checkpoint.assert_called_with( + self.config, + mock_manager.restore.return_value, + vae_only=False, + load_vae=None, + load_text_encoder=None, + load_transformer=None, + load_scheduler=None, + ) self.assertEqual(pipeline, mock_pipeline_instance) self.assertIsNotNone(opt_state) self.assertEqual(opt_state["learning_rate"], 0.001) self.assertEqual(step, 1) @patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager") - @patch("maxdiffusion.checkpointing.wan_checkpointer_2_2.WanPipeline2_2") - def test_load_checkpoint_with_optimizer_in_high_noise(self, mock_wan_pipeline, mock_create_manager): + @patch.object(WanPipeline2_2, "from_checkpoint", autospec=True) + def test_load_checkpoint_with_optimizer_in_high_noise(self, mock_from_checkpoint, mock_create_manager): """Test loading checkpoint with optimizer state in high_noise_transformer.""" mock_manager = MagicMock() mock_manager.latest_step.return_value = 1 @@ -363,13 +412,21 @@ def test_load_checkpoint_with_optimizer_in_high_noise(self, mock_wan_pipeline, m mock_create_manager.return_value = mock_manager mock_pipeline_instance = MagicMock() - mock_wan_pipeline.from_checkpoint.return_value = mock_pipeline_instance + mock_from_checkpoint.return_value = mock_pipeline_instance checkpointer = WanCheckpointer2_2(config=self.config) pipeline, opt_state, step = checkpointer.load_checkpoint(step=1) mock_manager.restore.assert_called_once_with(step=1, args=unittest.mock.ANY) - mock_wan_pipeline.from_checkpoint.assert_called_with(self.config, mock_manager.restore.return_value) + mock_from_checkpoint.assert_called_with( + self.config, + mock_manager.restore.return_value, + vae_only=False, + load_vae=None, + load_text_encoder=None, + load_transformer=None, + load_scheduler=None, + ) self.assertEqual(pipeline, mock_pipeline_instance) self.assertIsNotNone(opt_state) self.assertEqual(opt_state["learning_rate"], 0.002) @@ -399,7 +456,14 @@ def test_load_from_diffusers(self, mock_from_pretrained, mock_create_manager): pipeline, opt_state, step = checkpointer.load_checkpoint(step=None) mock_manager.latest_step.assert_called_once() - mock_from_pretrained.assert_called_once_with(self.config) + mock_from_pretrained.assert_called_once_with( + self.config, + vae_only=False, + load_vae=None, + load_text_encoder=None, + load_transformer=None, + load_scheduler=None, + ) self.assertEqual(pipeline, mock_pipeline_instance) self.assertIsNone(opt_state) self.assertIsNone(step) @@ -428,7 +492,15 @@ def test_load_checkpoint_no_optimizer(self, mock_from_checkpoint, mock_create_ma pipeline, opt_state, step = checkpointer.load_checkpoint(step=1) mock_manager.restore.assert_called_once() - mock_from_checkpoint.assert_called_once_with(self.config, restored_mock) + mock_from_checkpoint.assert_called_once_with( + self.config, + restored_mock, + vae_only=False, + load_vae=None, + load_text_encoder=None, + load_transformer=None, + load_scheduler=None, + ) self.assertEqual(pipeline, mock_pipeline_instance) self.assertIsNone(opt_state) self.assertEqual(step, 1) @@ -457,7 +529,15 @@ def test_load_checkpoint_with_optimizer(self, mock_from_checkpoint, mock_create_ pipeline, opt_state, step = checkpointer.load_checkpoint(step=1) mock_manager.restore.assert_called_once() - mock_from_checkpoint.assert_called_once_with(self.config, restored_mock) + mock_from_checkpoint.assert_called_once_with( + self.config, + restored_mock, + vae_only=False, + load_vae=None, + load_text_encoder=None, + load_transformer=None, + load_scheduler=None, + ) self.assertEqual(pipeline, mock_pipeline_instance) self.assertIsNotNone(opt_state) self.assertEqual(opt_state["learning_rate"], 0.001) @@ -474,27 +554,34 @@ def setUp(self): self.config.model_type = "I2V" @patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager") - @patch("maxdiffusion.checkpointing.wan_checkpointer_i2v_2p2.WanPipelineI2V_2_2") - def test_load_from_diffusers(self, mock_wan_pipeline_i2v_2p2, mock_create_manager): + @patch.object(WanPipelineI2V_2_2, "from_pretrained", autospec=True) + def test_load_from_diffusers(self, mock_from_pretrained, mock_create_manager): mock_manager = MagicMock() mock_manager.latest_step.return_value = None mock_create_manager.return_value = mock_manager mock_pipeline_instance = MagicMock() - mock_wan_pipeline_i2v_2p2.from_pretrained.return_value = mock_pipeline_instance + mock_from_pretrained.return_value = mock_pipeline_instance checkpointer = WanCheckpointerI2V_2_2(config=self.config) pipeline, opt_state, step = checkpointer.load_checkpoint(step=None) mock_manager.latest_step.assert_called_once() - mock_wan_pipeline_i2v_2p2.from_pretrained.assert_called_once_with(self.config) + mock_from_pretrained.assert_called_once_with( + self.config, + vae_only=False, + load_vae=None, + load_text_encoder=None, + load_transformer=None, + load_scheduler=None, + ) self.assertEqual(pipeline, mock_pipeline_instance) self.assertIsNone(opt_state) self.assertIsNone(step) @patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager") - @patch("maxdiffusion.checkpointing.wan_checkpointer_i2v_2p2.WanPipelineI2V_2_2") - def test_load_checkpoint_no_optimizer(self, mock_wan_pipeline_i2v_2p2, mock_create_manager): + @patch.object(WanPipelineI2V_2_2, "from_checkpoint", autospec=True) + def test_load_checkpoint_no_optimizer(self, mock_from_checkpoint, mock_create_manager): mock_manager = MagicMock() mock_manager.latest_step.return_value = 1 metadata_mock = MagicMock() @@ -512,20 +599,28 @@ def test_load_checkpoint_no_optimizer(self, mock_wan_pipeline_i2v_2p2, mock_crea mock_create_manager.return_value = mock_manager mock_pipeline_instance = MagicMock() - mock_wan_pipeline_i2v_2p2.from_checkpoint.return_value = mock_pipeline_instance + mock_from_checkpoint.return_value = mock_pipeline_instance checkpointer = WanCheckpointerI2V_2_2(config=self.config) pipeline, opt_state, step = checkpointer.load_checkpoint(step=1) mock_manager.restore.assert_called_once() - mock_wan_pipeline_i2v_2p2.from_checkpoint.assert_called_once_with(self.config, restored_mock) + mock_from_checkpoint.assert_called_once_with( + self.config, + restored_mock, + vae_only=False, + load_vae=None, + load_text_encoder=None, + load_transformer=None, + load_scheduler=None, + ) self.assertEqual(pipeline, mock_pipeline_instance) self.assertIsNone(opt_state) self.assertEqual(step, 1) @patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager") - @patch("maxdiffusion.checkpointing.wan_checkpointer_i2v_2p2.WanPipelineI2V_2_2") - def test_load_checkpoint_with_optimizer_in_low_noise(self, mock_wan_pipeline_i2v_2p2, mock_create_manager): + @patch.object(WanPipelineI2V_2_2, "from_checkpoint", autospec=True) + def test_load_checkpoint_with_optimizer_in_low_noise(self, mock_from_checkpoint, mock_create_manager): mock_manager = MagicMock() mock_manager.latest_step.return_value = 1 metadata_mock = MagicMock() @@ -543,21 +638,29 @@ def test_load_checkpoint_with_optimizer_in_low_noise(self, mock_wan_pipeline_i2v mock_create_manager.return_value = mock_manager mock_pipeline_instance = MagicMock() - mock_wan_pipeline_i2v_2p2.from_checkpoint.return_value = mock_pipeline_instance + mock_from_checkpoint.return_value = mock_pipeline_instance checkpointer = WanCheckpointerI2V_2_2(config=self.config) pipeline, opt_state, step = checkpointer.load_checkpoint(step=1) mock_manager.restore.assert_called_once() - mock_wan_pipeline_i2v_2p2.from_checkpoint.assert_called_once_with(self.config, restored_mock) + mock_from_checkpoint.assert_called_once_with( + self.config, + restored_mock, + vae_only=False, + load_vae=None, + load_text_encoder=None, + load_transformer=None, + load_scheduler=None, + ) self.assertEqual(pipeline, mock_pipeline_instance) self.assertIsNotNone(opt_state) self.assertEqual(opt_state["learning_rate"], 0.001) self.assertEqual(step, 1) @patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager") - @patch("maxdiffusion.checkpointing.wan_checkpointer_i2v_2p2.WanPipelineI2V_2_2") - def test_load_checkpoint_with_optimizer_in_high_noise(self, mock_wan_pipeline_i2v_2p2, mock_create_manager): + @patch.object(WanPipelineI2V_2_2, "from_checkpoint", autospec=True) + def test_load_checkpoint_with_optimizer_in_high_noise(self, mock_from_checkpoint, mock_create_manager): mock_manager = MagicMock() mock_manager.latest_step.return_value = 1 metadata_mock = MagicMock() @@ -575,13 +678,21 @@ def test_load_checkpoint_with_optimizer_in_high_noise(self, mock_wan_pipeline_i2 mock_create_manager.return_value = mock_manager mock_pipeline_instance = MagicMock() - mock_wan_pipeline_i2v_2p2.from_checkpoint.return_value = mock_pipeline_instance + mock_from_checkpoint.return_value = mock_pipeline_instance checkpointer = WanCheckpointerI2V_2_2(config=self.config) pipeline, opt_state, step = checkpointer.load_checkpoint(step=1) mock_manager.restore.assert_called_once() - mock_wan_pipeline_i2v_2p2.from_checkpoint.assert_called_once_with(self.config, restored_mock) + mock_from_checkpoint.assert_called_once_with( + self.config, + restored_mock, + vae_only=False, + load_vae=None, + load_text_encoder=None, + load_transformer=None, + load_scheduler=None, + ) self.assertEqual(pipeline, mock_pipeline_instance) self.assertIsNotNone(opt_state) self.assertEqual(opt_state["learning_rate"], 0.002) @@ -597,8 +708,8 @@ def setUp(self): self.config.dataset_type = "test_dataset" @patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager") - @patch("maxdiffusion.checkpointing.wan_checkpointer_2_1.WanPipeline2_1") - def test_load_checkpoint_with_explicit_none_step(self, mock_wan_pipeline, mock_create_manager): + @patch.object(WanPipeline2_1, "from_checkpoint", autospec=True) + def test_load_checkpoint_with_explicit_none_step(self, mock_from_checkpoint, mock_create_manager): """Test loading checkpoint with explicit None step falls back to latest.""" mock_manager = MagicMock() mock_manager.latest_step.return_value = 5 @@ -614,7 +725,7 @@ def test_load_checkpoint_with_explicit_none_step(self, mock_wan_pipeline, mock_c mock_create_manager.return_value = mock_manager mock_pipeline_instance = MagicMock() - mock_wan_pipeline.from_checkpoint.return_value = mock_pipeline_instance + mock_from_checkpoint.return_value = mock_pipeline_instance checkpointer = WanCheckpointer2_1(config=self.config) pipeline, opt_state, step = checkpointer.load_checkpoint(step=None) @@ -623,8 +734,8 @@ def test_load_checkpoint_with_explicit_none_step(self, mock_wan_pipeline, mock_c self.assertEqual(step, 5) @patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager") - @patch("maxdiffusion.checkpointing.wan_checkpointer_2_2.WanPipeline2_2") - def test_load_checkpoint_both_optimizers_present(self, mock_wan_pipeline, mock_create_manager): + @patch.object(WanPipeline2_2, "from_checkpoint", autospec=True) + def test_load_checkpoint_both_optimizers_present(self, mock_from_checkpoint, mock_create_manager): """Test loading checkpoint when both transformers have optimizer state (prioritize low_noise).""" mock_manager = MagicMock() mock_manager.latest_step.return_value = 1 @@ -642,7 +753,7 @@ def test_load_checkpoint_both_optimizers_present(self, mock_wan_pipeline, mock_c mock_create_manager.return_value = mock_manager mock_pipeline_instance = MagicMock() - mock_wan_pipeline.from_checkpoint.return_value = mock_pipeline_instance + mock_from_checkpoint.return_value = mock_pipeline_instance checkpointer = WanCheckpointer2_2(config=self.config) pipeline, opt_state, step = checkpointer.load_checkpoint(step=1) diff --git a/src/maxdiffusion/tests/wan/wan_vace_pipeline_test.py b/src/maxdiffusion/tests/wan/wan_vace_pipeline_test.py index 877c068a1..38acaddb3 100644 --- a/src/maxdiffusion/tests/wan/wan_vace_pipeline_test.py +++ b/src/maxdiffusion/tests/wan/wan_vace_pipeline_test.py @@ -277,6 +277,224 @@ def mock_load_scheduler(config): self.assertEqual(len(video), batch_size) self.assertEqual(video[0].shape, (num_frames, height, width, 3)) + @patch("maxdiffusion.pipelines.wan.wan_vace_pipeline_2_1.WanVACEModel.load_config") + @patch("maxdiffusion.pipelines.wan.wan_pipeline.AutoencoderKLWan.load_config") + @patch("maxdiffusion.pipelines.wan.wan_vace_pipeline_2_1.load_wan_transformer") + @patch("maxdiffusion.pipelines.wan.wan_pipeline.load_wan_vae") + @patch("maxdiffusion.pipelines.wan.wan_pipeline.WanPipeline.load_tokenizer") + @patch("maxdiffusion.pipelines.wan.wan_pipeline.WanPipeline.load_text_encoder") + @patch("maxdiffusion.pipelines.wan.wan_pipeline.WanPipeline.load_scheduler") + # pylint: disable=too-many-positional-arguments + def test_pipeline_load_without_vae_and_text_encoder( + self, + mock_load_scheduler_fn, + mock_load_text_encoder_fn, + mock_load_tokenizer_fn, + mock_load_wan_vae_fn, + mock_load_wan_transformer_fn, + mock_vae_load_config_fn, + mock_transformer_load_config_fn, + ): + def mock_transformer_load_config(pretrained_model_name_or_path, return_unused_kwargs=False, **kwargs): + config_dict = { + "added_kv_proj_dim": None, + "attention_head_dim": 128, + "cross_attn_norm": True, + "eps": 1e-06, + "ffn_dim": 8960, + "freq_dim": 256, + "image_dim": None, + "in_channels": 16, + "num_attention_heads": 12, + "num_layers": 2, + "out_channels": 16, + "patch_size": [1, 2, 2], + "pos_embed_seq_len": None, + "qk_norm": "rms_norm_across_heads", + "rope_max_seq_len": 1024, + "text_dim": 4096, + "vace_in_channels": 96, + "vace_layers": [0, 1], + } + if return_unused_kwargs: + return config_dict, kwargs + return config_dict + + mock_transformer_load_config_fn.side_effect = mock_transformer_load_config + + def mock_load_wan_transformer(pretrained_model_name_or_path, eval_shapes, *args, **kwargs): + cpu = jax.local_devices(backend="cpu")[0] + flat_shapes = flax.traverse_util.flatten_dict(eval_shapes) + flat_params = {} + key = jax.random.key(42) + for k, shape_struct in flat_shapes.items(): + dtype = shape_struct.dtype + shape = shape_struct.shape + key, subkey = jax.random.split(key) + val = jax.random.normal(subkey, shape, dtype=dtype) + flat_params[k] = jax.device_put(val, device=cpu) + return flax.traverse_util.unflatten_dict(flat_params) + + mock_load_wan_transformer_fn.side_effect = mock_load_wan_transformer + + def mock_load_scheduler(config): + scheduler = FlaxUniPCMultistepScheduler.from_config({ + "beta_end": 0.02, + "beta_schedule": "linear", + "beta_start": 0.0001, + "disable_corrector": [], + "dynamic_thresholding_ratio": 0.995, + "final_sigmas_type": "zero", + "flow_shift": config.flow_shift, + "lower_order_final": True, + "num_train_timesteps": 1000, + "predict_x0": True, + "prediction_type": "flow_prediction", + "rescale_zero_terminal_snr": False, + "sample_max_value": 1.0, + "solver_order": 2, + "solver_p": None, + "solver_type": "bh2", + "steps_offset": 0, + "thresholding": False, + "timestep_spacing": "linspace", + "trained_betas": None, + "use_beta_sigmas": False, + "use_exponential_sigmas": False, + "use_flow_sigmas": True, + "use_karras_sigmas": False, + }) + state = scheduler.create_state() + return scheduler, state + + mock_load_scheduler_fn.side_effect = mock_load_scheduler + + # VAE config mock + def mock_vae_load_config(pretrained_model_name_or_path, return_unused_kwargs=False, **kwargs): + config_dict = { + "attn_scales": [], + "base_dim": 96, + "dim_mult": [1, 2, 4, 4], + "dropout": 0.0, + "latents_mean": [ + -0.7571, + -0.7089, + -0.9113, + 0.1075, + -0.1745, + 0.9653, + -0.1517, + 1.5508, + 0.4134, + -0.0715, + 0.5517, + -0.3632, + -0.1922, + -0.9497, + 0.2503, + -0.2921, + ], + "latents_std": [ + 2.8184, + 1.4541, + 2.3275, + 2.6558, + 1.2196, + 1.7708, + 2.6052, + 2.0743, + 3.2687, + 2.1526, + 2.8652, + 1.5579, + 1.6382, + 1.1253, + 2.8251, + 1.916, + ], + "num_res_blocks": 2, + "temperal_downsample": [False, True, True], + "z_dim": 16, + } + if return_unused_kwargs: + return config_dict, kwargs + return config_dict + + mock_vae_load_config_fn.side_effect = mock_vae_load_config + + def mock_load_wan_vae(pretrained_model_name_or_path, eval_shapes, *args, **kwargs): + cpu = jax.local_devices(backend="cpu")[0] + flat_shapes = flax.traverse_util.flatten_dict(eval_shapes) + flat_params = {} + key = jax.random.key(42) + for k, shape_struct in flat_shapes.items(): + dtype = shape_struct.dtype + shape = shape_struct.shape + key, subkey = jax.random.split(key) + val = jax.random.normal(subkey, shape, dtype=dtype) + flat_params[k] = jax.device_put(val, device=cpu) + return flax.traverse_util.unflatten_dict(flat_params) + + mock_load_wan_vae_fn.side_effect = mock_load_wan_vae + + def run_scenario(load_vae, load_text_encoder, load_transformer, load_scheduler): + mock_load_wan_vae_fn.reset_mock() + mock_load_text_encoder_fn.reset_mock() + mock_load_tokenizer_fn.reset_mock() + mock_load_scheduler_fn.reset_mock() + mock_load_wan_transformer_fn.reset_mock() + + pipeline = VaceWanPipeline2_1.from_pretrained( + self.config, + load_vae=load_vae, + load_text_encoder=load_text_encoder, + load_transformer=load_transformer, + load_scheduler=load_scheduler, + ) + + if load_vae: + self.assertIsNotNone(pipeline.vae) + mock_load_wan_vae_fn.assert_called_once() + else: + self.assertIsNone(pipeline.vae) + mock_load_wan_vae_fn.assert_not_called() + + if load_text_encoder: + self.assertIsNotNone(pipeline.text_encoder) + self.assertIsNotNone(pipeline.tokenizer) + mock_load_text_encoder_fn.assert_called_once() + mock_load_tokenizer_fn.assert_called_once() + else: + self.assertIsNone(pipeline.text_encoder) + self.assertIsNone(pipeline.tokenizer) + mock_load_text_encoder_fn.assert_not_called() + mock_load_tokenizer_fn.assert_not_called() + + if load_transformer: + self.assertIsNotNone(pipeline.transformer) + mock_load_wan_transformer_fn.assert_called_once() + else: + self.assertIsNone(pipeline.transformer) + mock_load_wan_transformer_fn.assert_not_called() + + if load_scheduler: + self.assertIsNotNone(pipeline.scheduler) + mock_load_scheduler_fn.assert_called_once() + else: + self.assertIsNone(pipeline.scheduler) + mock_load_scheduler_fn.assert_not_called() + + # Scenario 1: Only transformer + run_scenario(load_vae=False, load_text_encoder=False, load_transformer=True, load_scheduler=False) + # Scenario 2: Only VAE + run_scenario(load_vae=True, load_text_encoder=False, load_transformer=False, load_scheduler=False) + # Scenario 3: Only text encoder + run_scenario(load_vae=False, load_text_encoder=True, load_transformer=False, load_scheduler=False) + # Scenario 4: Only scheduler + run_scenario(load_vae=False, load_text_encoder=False, load_transformer=False, load_scheduler=True) + # Scenario 5: All components + run_scenario(load_vae=True, load_text_encoder=True, load_transformer=True, load_scheduler=True) + if __name__ == "__main__": unittest.main()