Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 60 additions & 14 deletions src/maxdiffusion/checkpointing/wan_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,26 +16,27 @@

from abc import ABC, abstractmethod
import json
from typing import Optional, Tuple
from typing import Optional, Tuple, Generic, TypeVar, Type
import jax
from flax import nnx
from maxdiffusion.checkpointing.checkpointing_utils import (
add_sharding_to_struct,
create_orbax_checkpoint_manager,
get_cpu_mesh_and_sharding,
)
from ..pipelines.wan.wan_pipeline_2_1 import WanPipeline2_1
from ..pipelines.wan.wan_pipeline_2_2 import WanPipeline2_2
from ..pipelines.wan.wan_pipeline_i2v_2p1 import WanPipelineI2V_2_1
from ..pipelines.wan.wan_pipeline_i2v_2p2 import WanPipelineI2V_2_2
from ..pipelines.wan.wan_pipeline import WanPipeline
from .. import max_logging, max_utils
import orbax.checkpoint as ocp


WAN_CHECKPOINT = "WAN_CHECKPOINT"


class WanCheckpointer(ABC):
T = TypeVar("T", bound=WanPipeline)


class WanCheckpointer(Generic[T], ABC):
pipeline_class: Optional[Type[T]] = None

def __init__(self, config, checkpoint_type: str = WAN_CHECKPOINT):
self.config = config
Expand Down Expand Up @@ -176,16 +177,61 @@ def _pretrained_save_items(pipeline, pretrained_state_sources, pretrained_config
def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]:
raise NotImplementedError

@abstractmethod
def load_diffusers_checkpoint(self):
raise NotImplementedError
def load_diffusers_checkpoint(
self,
vae_only=False,
load_vae=None,
load_text_encoder=None,
load_transformer=None,
load_scheduler=None,
) -> T:
pipeline = self.pipeline_class.from_pretrained(
self.config,
vae_only=vae_only,
load_vae=load_vae,
load_text_encoder=load_text_encoder,
load_transformer=load_transformer,
load_scheduler=load_scheduler,
)
return pipeline

@abstractmethod
def load_checkpoint(
self, step=None
) -> Tuple[
Optional[WanPipeline2_1 | WanPipeline2_2 | WanPipelineI2V_2_1 | WanPipelineI2V_2_2], Optional[dict], Optional[int]
]:
self,
step=None,
vae_only=False,
load_vae=None,
load_text_encoder=None,
load_transformer=None,
load_scheduler=None,
) -> Tuple[T, Optional[dict], Optional[int]]:
restored_checkpoint, step = self.load_wan_configs_from_orbax(step)
opt_state = None
if restored_checkpoint:
max_logging.log("Loading WAN pipeline from checkpoint")
pipeline = self.pipeline_class.from_checkpoint(
self.config,
restored_checkpoint,
vae_only=vae_only,
load_vae=load_vae,
load_text_encoder=load_text_encoder,
load_transformer=load_transformer,
load_scheduler=load_scheduler,
)
opt_state = self._extract_opt_state(restored_checkpoint)
else:
max_logging.log("No checkpoint found, loading default pipeline.")
pipeline = self.load_diffusers_checkpoint(
vae_only=vae_only,
load_vae=load_vae,
load_text_encoder=load_text_encoder,
load_transformer=load_transformer,
load_scheduler=load_scheduler,
)

return pipeline, opt_state, step

@abstractmethod
def _extract_opt_state(self, restored_checkpoint):
raise NotImplementedError

@abstractmethod
Expand Down
24 changes: 6 additions & 18 deletions src/maxdiffusion/checkpointing/wan_checkpointer_2_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
from ..pipelines.wan.wan_pipeline_2_1 import WanPipeline2_1


class WanCheckpointer2_1(WanCheckpointer):
class WanCheckpointer2_1(WanCheckpointer[WanPipeline2_1]):
pipeline_class = WanPipeline2_1

def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]:
if step is None:
Expand Down Expand Up @@ -58,23 +59,10 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic
max_logging.log(f"optimizer state saved in attribute self.opt_state {self.opt_state}")
return restored_checkpoint, step

def load_diffusers_checkpoint(self):
pipeline = WanPipeline2_1.from_pretrained(self.config)
return pipeline

def load_checkpoint(self, step=None) -> Tuple[WanPipeline2_1, Optional[dict], Optional[int]]:
restored_checkpoint, step = self.load_wan_configs_from_orbax(step)
opt_state = None
if restored_checkpoint:
max_logging.log("Loading WAN pipeline from checkpoint")
pipeline = WanPipeline2_1.from_checkpoint(self.config, restored_checkpoint)
if "opt_state" in restored_checkpoint.wan_state.keys():
opt_state = restored_checkpoint.wan_state["opt_state"]
else:
max_logging.log("No checkpoint found, loading default pipeline.")
pipeline = self.load_diffusers_checkpoint()

return pipeline, opt_state, step
def _extract_opt_state(self, restored_checkpoint):
if "opt_state" in restored_checkpoint.wan_state.keys():
return restored_checkpoint.wan_state["opt_state"]
return None

def save_checkpoint(self, train_step, pipeline: WanPipeline2_1, train_states: dict):
"""Saves the training state and model configurations."""
Expand Down
29 changes: 8 additions & 21 deletions src/maxdiffusion/checkpointing/wan_checkpointer_2_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer


class WanCheckpointer2_2(WanCheckpointer):
class WanCheckpointer2_2(WanCheckpointer[WanPipeline2_2]):
pipeline_class = WanPipeline2_2

def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]:
if step is None:
Expand Down Expand Up @@ -79,26 +80,12 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic
max_logging.log(f"optimizer state saved in attribute self.opt_state {self.opt_state}")
return restored_checkpoint, step

def load_diffusers_checkpoint(self):
pipeline = WanPipeline2_2.from_pretrained(self.config)
return pipeline

def load_checkpoint(self, step=None) -> Tuple[WanPipeline2_2, Optional[dict], Optional[int]]:
restored_checkpoint, step = self.load_wan_configs_from_orbax(step)
opt_state = None
if restored_checkpoint:
max_logging.log("Loading WAN pipeline from checkpoint")
pipeline = WanPipeline2_2.from_checkpoint(self.config, restored_checkpoint)
# Check for optimizer state in either transformer
if "opt_state" in restored_checkpoint.low_noise_transformer_state.keys():
opt_state = restored_checkpoint.low_noise_transformer_state["opt_state"]
elif "opt_state" in restored_checkpoint.high_noise_transformer_state.keys():
opt_state = restored_checkpoint.high_noise_transformer_state["opt_state"]
else:
max_logging.log("No checkpoint found, loading default pipeline.")
pipeline = self.load_diffusers_checkpoint()

return pipeline, opt_state, step
def _extract_opt_state(self, restored_checkpoint):
if "opt_state" in restored_checkpoint.low_noise_transformer_state.keys():
return restored_checkpoint.low_noise_transformer_state["opt_state"]
elif "opt_state" in restored_checkpoint.high_noise_transformer_state.keys():
return restored_checkpoint.high_noise_transformer_state["opt_state"]
return None

def save_checkpoint(self, train_step, pipeline: WanPipeline2_2, train_states: dict):
"""Saves the training state and model configurations."""
Expand Down
24 changes: 6 additions & 18 deletions src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p1.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
from ..pipelines.wan.wan_pipeline_i2v_2p1 import WanPipelineI2V_2_1


class WanCheckpointerI2V_2_1(WanCheckpointer):
class WanCheckpointerI2V_2_1(WanCheckpointer[WanPipelineI2V_2_1]):
pipeline_class = WanPipelineI2V_2_1

def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]:
if step is None:
Expand Down Expand Up @@ -58,23 +59,10 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic
max_logging.log(f"optimizer state saved in attribute self.opt_state {self.opt_state}")
return restored_checkpoint, step

def load_diffusers_checkpoint(self):
pipeline = WanPipelineI2V_2_1.from_pretrained(self.config)
return pipeline

def load_checkpoint(self, step=None) -> Tuple[WanPipelineI2V_2_1, Optional[dict], Optional[int]]:
restored_checkpoint, step = self.load_wan_configs_from_orbax(step)
opt_state = None
if restored_checkpoint:
max_logging.log("Loading WAN pipeline from checkpoint")
pipeline = WanPipelineI2V_2_1.from_checkpoint(self.config, restored_checkpoint)
if "opt_state" in restored_checkpoint.wan_state.keys():
opt_state = restored_checkpoint.wan_state["opt_state"]
else:
max_logging.log("No checkpoint found, loading default pipeline.")
pipeline = self.load_diffusers_checkpoint()

return pipeline, opt_state, step
def _extract_opt_state(self, restored_checkpoint):
if "opt_state" in restored_checkpoint.wan_state.keys():
return restored_checkpoint.wan_state["opt_state"]
return None

def save_checkpoint(self, train_step, pipeline: WanPipelineI2V_2_1, train_states: dict):
"""Saves the training state and model configurations."""
Expand Down
29 changes: 8 additions & 21 deletions src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p2.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer


class WanCheckpointerI2V_2_2(WanCheckpointer):
class WanCheckpointerI2V_2_2(WanCheckpointer[WanPipelineI2V_2_2]):
pipeline_class = WanPipelineI2V_2_2

def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]:
if step is None:
Expand Down Expand Up @@ -79,26 +80,12 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic
max_logging.log(f"optimizer state saved in attribute self.opt_state {self.opt_state}")
return restored_checkpoint, step

def load_diffusers_checkpoint(self):
pipeline = WanPipelineI2V_2_2.from_pretrained(self.config)
return pipeline

def load_checkpoint(self, step=None) -> Tuple[WanPipelineI2V_2_2, Optional[dict], Optional[int]]:
restored_checkpoint, step = self.load_wan_configs_from_orbax(step)
opt_state = None
if restored_checkpoint:
max_logging.log("Loading WAN pipeline from checkpoint")
pipeline = WanPipelineI2V_2_2.from_checkpoint(self.config, restored_checkpoint)
# Check for optimizer state in either transformer
if "opt_state" in restored_checkpoint.low_noise_transformer_state.keys():
opt_state = restored_checkpoint.low_noise_transformer_state["opt_state"]
elif "opt_state" in restored_checkpoint.high_noise_transformer_state.keys():
opt_state = restored_checkpoint.high_noise_transformer_state["opt_state"]
else:
max_logging.log("No checkpoint found, loading default pipeline.")
pipeline = self.load_diffusers_checkpoint()

return pipeline, opt_state, step
def _extract_opt_state(self, restored_checkpoint):
if "opt_state" in restored_checkpoint.low_noise_transformer_state.keys():
return restored_checkpoint.low_noise_transformer_state["opt_state"]
elif "opt_state" in restored_checkpoint.high_noise_transformer_state.keys():
return restored_checkpoint.high_noise_transformer_state["opt_state"]
return None

def save_checkpoint(self, train_step, pipeline: WanPipelineI2V_2_2, train_states: dict):
"""Saves the training state and model configurations."""
Expand Down
24 changes: 6 additions & 18 deletions src/maxdiffusion/checkpointing/wan_vace_checkpointer_2_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
from ..pipelines.wan.wan_vace_pipeline_2_1 import VaceWanPipeline2_1


class WanVaceCheckpointer2_1(WanCheckpointer):
class WanVaceCheckpointer2_1(WanCheckpointer[VaceWanPipeline2_1]):
pipeline_class = VaceWanPipeline2_1

def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]:
if step is None:
Expand Down Expand Up @@ -57,23 +58,10 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic
max_logging.log(f"optimizer state saved in attribute self.opt_state {self.opt_state}")
return restored_checkpoint, step

def load_diffusers_checkpoint(self):
pipeline = VaceWanPipeline2_1.from_pretrained(self.config)
return pipeline

def load_checkpoint(self, step=None) -> Tuple[VaceWanPipeline2_1, Optional[dict], Optional[int]]:
restored_checkpoint, step = self.load_wan_configs_from_orbax(step)
opt_state = None
if restored_checkpoint:
max_logging.log("Loading WAN pipeline from checkpoint")
pipeline = VaceWanPipeline2_1.from_checkpoint(self.config, restored_checkpoint)
if "opt_state" in restored_checkpoint.wan_state.keys():
opt_state = restored_checkpoint.wan_state["opt_state"]
else:
max_logging.log("No checkpoint found, loading default pipeline.")
pipeline = self.load_diffusers_checkpoint()

return pipeline, opt_state, step
def _extract_opt_state(self, restored_checkpoint):
if "opt_state" in restored_checkpoint.wan_state.keys():
return restored_checkpoint.wan_state["opt_state"]
return None

def save_checkpoint(self, train_step, pipeline: VaceWanPipeline2_1, train_states: dict):
"""Saves the training state and model configurations."""
Expand Down
Loading
Loading