From 36e0f5dbdd8d38a8d479795289ff3acad2c82b78 Mon Sep 17 00:00:00 2001 From: Rishabh Manoj Date: Thu, 18 Jun 2026 18:23:56 +0000 Subject: [PATCH] LTX2.3 improvements and bug fixes --- src/maxdiffusion/generate_ltx2.py | 3 +- src/maxdiffusion/models/attention_flax.py | 4 +- .../models/ltx2/attention_ltx2.py | 17 +- .../text_encoders/torchax_text_encoder.py | 10 +- .../models/ltx2/transformer_ltx2.py | 303 ++++++++++-------- .../pipelines/ltx2/ltx2_pipeline.py | 130 ++++---- .../tests/ltx2/test_transformer_ltx2.py | 5 +- 7 files changed, 249 insertions(+), 223 deletions(-) diff --git a/src/maxdiffusion/generate_ltx2.py b/src/maxdiffusion/generate_ltx2.py index 0445913ef..8d77a2d95 100644 --- a/src/maxdiffusion/generate_ltx2.py +++ b/src/maxdiffusion/generate_ltx2.py @@ -237,7 +237,8 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None): # Export videos for i in range(len(videos)): - video_path = f"{filename_prefix}ltx2_output_{getattr(config, 'seed', 0)}_{i}.mp4" + model_name_prefix = getattr(config, "model_name", "ltx2").replace(".", "_") + video_path = f"{filename_prefix}{model_name_prefix}_output_{getattr(config, 'seed', 0)}_{i}.mp4" audio_i = audios[i] if audios is not None else None audio_format = getattr(config, "audio_format", "s16") diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index edc9f4f7b..fe1356d48 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -1383,7 +1383,7 @@ def __init__( self, mesh: Mesh, attention_kernel: str, - scale: int, + scale: float, heads: int, dim_head: int, use_memory_efficient_attention: bool = False, @@ -1475,7 +1475,7 @@ def apply_attention(self, query: Array, key: Array, value: Array, attention_mask class AttentionOp(nn.Module): mesh: Mesh attention_kernel: str - scale: int + scale: float heads: int dim_head: int use_memory_efficient_attention: bool = False diff --git a/src/maxdiffusion/models/ltx2/attention_ltx2.py b/src/maxdiffusion/models/ltx2/attention_ltx2.py index 214690c98..a061c404f 100644 --- a/src/maxdiffusion/models/ltx2/attention_ltx2.py +++ b/src/maxdiffusion/models/ltx2/attention_ltx2.py @@ -88,13 +88,8 @@ def apply_split_rotary_emb(x: Array, freqs: Tuple[Array, Array]) -> Array: first_x = split_x[..., 0, :] second_x = split_x[..., 1, :] - cos_u = jnp.expand_dims(cos, axis=-2) - sin_u = jnp.expand_dims(sin, axis=-2) - - out = split_x * cos_u - - out_first = out[..., 0, :] - second_x * sin_u.squeeze(-2) - out_second = out[..., 1, :] + first_x * sin_u.squeeze(-2) + out_first = first_x * cos - second_x * sin + out_second = second_x * cos + first_x * sin out = jnp.stack([out_first, out_second], axis=-2) out = out.reshape(*out.shape[:-2], last_dim) @@ -176,12 +171,6 @@ def prepare_video_coords( patch_ends = grid + patch_size_delta # Combine start and end coordinates - latent_coords = jnp.stack([grid, patch_ends], axis=-1) # [3, N_F, N_H, N_W, 2] - latent_coords = latent_coords.transpose(1, 2, 3, 0, 4) # [N_F, N_H, N_W, 3, 2] - latent_coords = latent_coords.reshape(-1, 3, 2) # [num_patches, 3, 2] - latent_coords = jnp.expand_dims(latent_coords, 0) # [1, num_patches, 3, 2] - latent_coords = jnp.tile(latent_coords, (batch_size, 1, 1, 1)) # [B, num_patches, 3, 2] - latent_coords = jnp.stack([grid, patch_ends], axis=-1) # [3, N_F, N_H, N_W, 2] latent_coords = latent_coords.reshape(3, -1, 2) # [3, num_patches, 2] latent_coords = jnp.expand_dims(latent_coords, 0) # [1, 3, num_patches, 2] @@ -485,7 +474,7 @@ def __call__( # 3. Apply RoPE with jax.named_scope("Apply RoPE"): if rotary_emb is not None: - if hasattr(self, "rope_type") and self.rope_type == "split": + if self.rope_type == "split": # Split RoPE: passing full freqs [B, H, S, D//2] # apply_split_rotary_emb handles reshaping query/key diff --git a/src/maxdiffusion/models/ltx2/text_encoders/torchax_text_encoder.py b/src/maxdiffusion/models/ltx2/text_encoders/torchax_text_encoder.py index 882c65a25..456e6d9e3 100644 --- a/src/maxdiffusion/models/ltx2/text_encoders/torchax_text_encoder.py +++ b/src/maxdiffusion/models/ltx2/text_encoders/torchax_text_encoder.py @@ -21,9 +21,7 @@ from torchax import interop, default_env # --- Monkeypatch transformers masking_utils to avoid torchax integer tracing bug --- -import transformers.masking_utils - -_orig_sliding_window_overlay = transformers.masking_utils.sliding_window_overlay +from unittest import mock def _patched_sliding_window_overlay(sliding_window: int): @@ -57,8 +55,7 @@ def __call__( self, input_ids: jax.Array, attention_mask: jax.Array, output_hidden_states: bool = True ) -> Tuple[jax.Array, ...]: # Dynamically patch transformers.masking_utils only during the duration of this call - transformers.masking_utils.sliding_window_overlay = _patched_sliding_window_overlay - try: + with mock.patch("transformers.masking_utils.sliding_window_overlay", _patched_sliding_window_overlay): with default_env(): input_ids = interop.torch_view(input_ids) attention_mask = interop.torch_view(attention_mask) @@ -72,9 +69,6 @@ def __call__( output_hidden_states=output_hidden_states, ) return interop.jax_view(output) - finally: - # Restore original behavior to prevent side effects on other potential models in same env - transformers.masking_utils.sliding_window_overlay = _orig_sliding_window_overlay @staticmethod def _forward_inner(model, input_ids, attention_mask, output_hidden_states=True): diff --git a/src/maxdiffusion/models/ltx2/transformer_ltx2.py b/src/maxdiffusion/models/ltx2/transformer_ltx2.py index f2511ed07..341290f30 100644 --- a/src/maxdiffusion/models/ltx2/transformer_ltx2.py +++ b/src/maxdiffusion/models/ltx2/transformer_ltx2.py @@ -27,6 +27,34 @@ from maxdiffusion.configuration_utils import ConfigMixin, register_to_config from maxdiffusion.common_types import BlockSizes from .logical_sharding_ltx2 import get_sharding_specs, LTX2DiTShardingSpecs +import dataclasses +from flax import struct + + +@struct.dataclass +class LTX2BlockContext: + hidden_states: jax.Array + audio_hidden_states: jax.Array + encoder_hidden_states: jax.Array + audio_encoder_hidden_states: jax.Array + temb: jax.Array + temb_audio: jax.Array + temb_ca_scale_shift: jax.Array + temb_ca_audio_scale_shift: jax.Array + temb_ca_gate: jax.Array + temb_ca_audio_gate: jax.Array + temb_prompt: Optional[jax.Array] = None + temb_prompt_audio: Optional[jax.Array] = None + modality_mask: Optional[jax.Array] = None + video_rotary_emb: Optional[Tuple[jax.Array, jax.Array]] = None + audio_rotary_emb: Optional[Tuple[jax.Array, jax.Array]] = None + ca_video_rotary_emb: Optional[Tuple[jax.Array, jax.Array]] = None + ca_audio_rotary_emb: Optional[Tuple[jax.Array, jax.Array]] = None + encoder_attention_mask: Optional[jax.Array] = None + audio_encoder_attention_mask: Optional[jax.Array] = None + a2v_cross_attention_mask: Optional[jax.Array] = None + v2a_cross_attention_mask: Optional[jax.Array] = None + perturbation_mask: Optional[jax.Array] = None class LTX2AdaLayerNormSingle(nnx.Module): @@ -279,7 +307,7 @@ def __init__( attention_kernel=a2v_attention_kernel, rope_type=rope_type, flash_block_sizes=flash_block_sizes, - flash_min_seq_length=0, + flash_min_seq_length=flash_min_seq_length, sharding_specs=self.sharding_specs, gated_attn=gated_attn, ) @@ -399,31 +427,49 @@ def __init__( def __call__( self, - hidden_states: jax.Array, # Video - audio_hidden_states: jax.Array, # Audio - encoder_hidden_states: jax.Array, # Context (Text) - audio_encoder_hidden_states: jax.Array, # Audio Context - # Timestep embeddings for AdaLN - temb: jax.Array, - temb_audio: jax.Array, - temb_ca_scale_shift: jax.Array, - temb_ca_audio_scale_shift: jax.Array, - temb_ca_gate: jax.Array, - temb_ca_audio_gate: jax.Array, - temb_prompt: Optional[jax.Array] = None, - temb_prompt_audio: Optional[jax.Array] = None, - modality_mask: Optional[jax.Array] = None, - # RoPE - video_rotary_emb: Optional[Tuple[jax.Array, jax.Array]] = None, - audio_rotary_emb: Optional[Tuple[jax.Array, jax.Array]] = None, - ca_video_rotary_emb: Optional[Tuple[jax.Array, jax.Array]] = None, - ca_audio_rotary_emb: Optional[Tuple[jax.Array, jax.Array]] = None, - encoder_attention_mask: Optional[jax.Array] = None, - audio_encoder_attention_mask: Optional[jax.Array] = None, - a2v_cross_attention_mask: Optional[jax.Array] = None, - v2a_cross_attention_mask: Optional[jax.Array] = None, - perturbation_mask: Optional[jax.Array] = None, + ctx: "LTX2BlockContext", ) -> Tuple[jax.Array, jax.Array]: + """ + Forward pass of the LTX2 video/audio transformer block. + + This block handles complex multi-modal attention including: + - Video Self-Attention (video -> video) + - Audio Self-Attention (audio -> audio) + - Video Cross-Attention (video -> text caption) + - Audio Cross-Attention (audio -> text caption) + - Video-to-Audio Cross-Attention + - Audio-to-Video Cross-Attention + + Args: + ctx: An `LTX2BlockContext` object containing all hidden states, timestep + embeddings, attention masks, rotary embeddings, and modulation + parameters needed for this layer's forward pass. + + Returns: + A tuple of `(output_hidden_states, output_audio_hidden_states)`. + """ + hidden_states = ctx.hidden_states + audio_hidden_states = ctx.audio_hidden_states + encoder_hidden_states = ctx.encoder_hidden_states + audio_encoder_hidden_states = ctx.audio_encoder_hidden_states + temb = ctx.temb + temb_audio = ctx.temb_audio + temb_ca_scale_shift = ctx.temb_ca_scale_shift + temb_ca_audio_scale_shift = ctx.temb_ca_audio_scale_shift + temb_ca_gate = ctx.temb_ca_gate + temb_ca_audio_gate = ctx.temb_ca_audio_gate + temb_prompt = ctx.temb_prompt + temb_prompt_audio = ctx.temb_prompt_audio + modality_mask = ctx.modality_mask + video_rotary_emb = ctx.video_rotary_emb + audio_rotary_emb = ctx.audio_rotary_emb + ca_video_rotary_emb = ctx.ca_video_rotary_emb + ca_audio_rotary_emb = ctx.ca_audio_rotary_emb + encoder_attention_mask = ctx.encoder_attention_mask + audio_encoder_attention_mask = ctx.audio_encoder_attention_mask + a2v_cross_attention_mask = ctx.a2v_cross_attention_mask + v2a_cross_attention_mask = ctx.v2a_cross_attention_mask + perturbation_mask = ctx.perturbation_mask batch_size = hidden_states.shape[0] axis_names = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_embed")) @@ -879,6 +925,7 @@ def __init__( # 3. Output Layer Scale/Shift Modulation parameters param_rng = rngs.params() + audio_param_rng = rngs.params() table_sharding = self.sharding_specs.scale_shift_table self.scale_shift_table = nnx.Param( nnx.with_partitioning( @@ -889,7 +936,7 @@ def __init__( nnx.with_partitioning( lambda key, shape: jax.random.normal(key, shape, dtype=self.weights_dtype) / jnp.sqrt(audio_inner_dim), table_sharding, - )(param_rng, (2, audio_inner_dim)) + )(audio_param_rng, (2, audio_inner_dim)) ) # 4. Rotary Positional Embeddings (RoPE) @@ -999,6 +1046,7 @@ def init_block(rngs): for _ in range(self.num_layers): block = LTX2VideoTransformerBlock( rngs=rngs, + sharding_specs=self.sharding_specs, dim=inner_dim, num_attention_heads=self.num_attention_heads, attention_head_dim=self.attention_head_dim, @@ -1085,6 +1133,36 @@ def __call__( return_dict: bool = True, perturbation_mask: Optional[jax.Array] = None, ) -> Any: + """ + Forward pass for the full LTX2 Video/Audio Diffusion Transformer. + + Args: + hidden_states: Video latent patches of shape `(batch, seq_len, in_channels)`. + audio_hidden_states: Audio latent patches of shape `(batch, audio_seq_len, audio_in_channels)`. + encoder_hidden_states: Text embeddings for video generation. + audio_encoder_hidden_states: Text embeddings for audio generation. + timestep: Timestep array for video diffusion. + audio_timestep: Optional timestep array for audio diffusion. If None, uses `timestep`. + sigma: Optional noise scale for video (for flow matching). + audio_sigma: Optional noise scale for audio. + encoder_attention_mask: Mask for video text embeddings. + audio_encoder_attention_mask: Mask for audio text embeddings. + num_frames: Number of video frames. + height: Height of the video frames. + width: Width of the video frames. + fps: Frames per second. + audio_num_frames: Number of audio frames. + video_coords: Optional pre-computed 3D coordinates for video RoPE. + audio_coords: Optional pre-computed 1D coordinates for audio RoPE. + attention_kwargs: Additional kwargs for the attention mechanisms. + use_cross_timestep: Whether to use a cross-modal timestep interaction. + modality_mask: Mask indicating which modality to drop/keep. + return_dict: If True, returns a dictionary. Otherwise, returns a tuple. + perturbation_mask: Optional mask for perturbing attention. + + Returns: + Output dict containing `sample` (video) and `audio_sample` (audio). + """ # Determine timestep for audio. audio_timestep = audio_timestep if audio_timestep is not None else timestep @@ -1152,9 +1230,8 @@ def __call__( temb_prompt_audio = None if use_cross_timestep: - assert ( - sigma is not None and audio_sigma is not None - ), "sigma and audio_sigma must be provided when use_cross_timestep is True" + if sigma is None or audio_sigma is None: + raise ValueError("sigma and audio_sigma must be provided when use_cross_timestep is True") video_ca_timestep = audio_sigma.flatten() audio_ca_timestep = sigma.flatten() else: @@ -1195,38 +1272,47 @@ def __call__( audio_encoder_hidden_states = audio_encoder_hidden_states.reshape(batch_size, -1, audio_hidden_states.shape[-1]) # 5. Run transformer blocks with jax.named_scope("Transformer Blocks"): + base_context = LTX2BlockContext( + hidden_states=hidden_states, + audio_hidden_states=audio_hidden_states, + encoder_hidden_states=encoder_hidden_states, + audio_encoder_hidden_states=audio_encoder_hidden_states, + temb=temb, + temb_audio=temb_audio, + temb_ca_scale_shift=video_cross_attn_scale_shift, + temb_ca_audio_scale_shift=audio_cross_attn_scale_shift, + temb_ca_gate=video_cross_attn_a2v_gate, + temb_ca_audio_gate=audio_cross_attn_v2a_gate, + temb_prompt=temb_prompt, + temb_prompt_audio=temb_prompt_audio, + video_rotary_emb=video_rotary_emb, + audio_rotary_emb=audio_rotary_emb, + ca_video_rotary_emb=video_cross_attn_rotary_emb, + ca_audio_rotary_emb=audio_cross_attn_rotary_emb, + encoder_attention_mask=encoder_attention_mask, + audio_encoder_attention_mask=audio_encoder_attention_mask, + modality_mask=modality_mask, + ) + + def apply_block(block, context: LTX2BlockContext, mask, rngs_carry): + orig_perturbation_mask = context.perturbation_mask + context = dataclasses.replace(context, perturbation_mask=mask) + with jax.named_scope("Transformer Layer"): + hidden_states_out, audio_hidden_states_out = block(context) + context = dataclasses.replace( + context, + hidden_states=hidden_states_out.astype(context.hidden_states.dtype), + audio_hidden_states=audio_hidden_states_out.astype(context.audio_hidden_states.dtype), + perturbation_mask=orig_perturbation_mask, + ) + return context, rngs_carry + if perturbation_mask is None: - # Fast-path: No perturbation masking (standard LTX-2 or disabled STG) + def scan_fn_ltx2(carry, block): - hidden_states, audio_hidden_states, rngs_carry = carry - with jax.named_scope("Transformer Layer"): - hidden_states_out, audio_hidden_states_out = block( - hidden_states=hidden_states, - audio_hidden_states=audio_hidden_states, - encoder_hidden_states=encoder_hidden_states, - audio_encoder_hidden_states=audio_encoder_hidden_states, - temb=temb, - temb_audio=temb_audio, - temb_ca_scale_shift=video_cross_attn_scale_shift, - temb_ca_audio_scale_shift=audio_cross_attn_scale_shift, - temb_ca_gate=video_cross_attn_a2v_gate, - temb_ca_audio_gate=audio_cross_attn_v2a_gate, - temb_prompt=temb_prompt, - temb_prompt_audio=temb_prompt_audio, - video_rotary_emb=video_rotary_emb, - audio_rotary_emb=audio_rotary_emb, - ca_video_rotary_emb=video_cross_attn_rotary_emb, - ca_audio_rotary_emb=audio_cross_attn_rotary_emb, - encoder_attention_mask=encoder_attention_mask, - audio_encoder_attention_mask=audio_encoder_attention_mask, - perturbation_mask=None, - modality_mask=modality_mask, - ) - return ( - hidden_states_out.astype(hidden_states.dtype), - audio_hidden_states_out.astype(audio_hidden_states.dtype), - rngs_carry, - ), None + context, rngs_carry = carry + context, rngs_carry = apply_block(block, context, None, rngs_carry) + return (context, rngs_carry), None if self.scan_layers: rematted_scan_fn = self.gradient_checkpoint.apply( @@ -1235,40 +1321,22 @@ def scan_fn_ltx2(carry, block): self.names_which_can_be_offloaded, prevent_cse=not self.scan_layers, ) - carry = (hidden_states, audio_hidden_states, nnx.Rngs(0)) - (hidden_states, audio_hidden_states, _), _ = nnx.scan( + carry = (base_context, nnx.Rngs(0)) + (final_context, _), _ = nnx.scan( rematted_scan_fn, length=self.num_layers, in_axes=(nnx.Carry, 0), out_axes=(nnx.Carry, 0), - transform_metadata={nnx.PARTITION_NAME: "layers"}, )(carry, self.transformer_blocks) + hidden_states = final_context.hidden_states + audio_hidden_states = final_context.audio_hidden_states else: + current_context = base_context for block in self.transformer_blocks: - hidden_states, audio_hidden_states = block( - hidden_states=hidden_states, - audio_hidden_states=audio_hidden_states, - encoder_hidden_states=encoder_hidden_states, - audio_encoder_hidden_states=audio_encoder_hidden_states, - temb=temb, - temb_audio=temb_audio, - temb_ca_scale_shift=video_cross_attn_scale_shift, - temb_ca_audio_scale_shift=audio_cross_attn_scale_shift, - temb_ca_gate=video_cross_attn_a2v_gate, - temb_ca_audio_gate=audio_cross_attn_v2a_gate, - temb_prompt=temb_prompt, - temb_prompt_audio=temb_prompt_audio, - video_rotary_emb=video_rotary_emb, - audio_rotary_emb=audio_rotary_emb, - ca_video_rotary_emb=video_cross_attn_rotary_emb, - ca_audio_rotary_emb=audio_cross_attn_rotary_emb, - encoder_attention_mask=encoder_attention_mask, - audio_encoder_attention_mask=audio_encoder_attention_mask, - perturbation_mask=None, - modality_mask=modality_mask, - ) + current_context, _ = apply_block(block, current_context, None, None) + hidden_states = current_context.hidden_states + audio_hidden_states = current_context.audio_hidden_states else: - # Slow-path: Dynamic perturbation masking (LTX-2.3 STG enabled) masks = jnp.ones((self.num_layers, batch_size, 1, 1), dtype=self.dtype) for i in self.spatio_temporal_guidance_blocks: if i < self.num_layers: @@ -1277,35 +1345,9 @@ def scan_fn_ltx2(carry, block): def scan_fn_ltx23(carry, block_and_mask): block, mask = block_and_mask - hidden_states, audio_hidden_states, rngs_carry = carry - with jax.named_scope("Transformer Layer"): - hidden_states_out, audio_hidden_states_out = block( - hidden_states=hidden_states, - audio_hidden_states=audio_hidden_states, - encoder_hidden_states=encoder_hidden_states, - audio_encoder_hidden_states=audio_encoder_hidden_states, - temb=temb, - temb_audio=temb_audio, - temb_ca_scale_shift=video_cross_attn_scale_shift, - temb_ca_audio_scale_shift=audio_cross_attn_scale_shift, - temb_ca_gate=video_cross_attn_a2v_gate, - temb_ca_audio_gate=audio_cross_attn_v2a_gate, - temb_prompt=temb_prompt, - temb_prompt_audio=temb_prompt_audio, - video_rotary_emb=video_rotary_emb, - audio_rotary_emb=audio_rotary_emb, - ca_video_rotary_emb=video_cross_attn_rotary_emb, - ca_audio_rotary_emb=audio_cross_attn_rotary_emb, - encoder_attention_mask=encoder_attention_mask, - audio_encoder_attention_mask=audio_encoder_attention_mask, - perturbation_mask=mask, - modality_mask=modality_mask, - ) - return ( - hidden_states_out.astype(hidden_states.dtype), - audio_hidden_states_out.astype(audio_hidden_states.dtype), - rngs_carry, - ), None + context, rngs_carry = carry + context, rngs_carry = apply_block(block, context, mask, rngs_carry) + return (context, rngs_carry), None if self.scan_layers: rematted_scan_fn = self.gradient_checkpoint.apply( @@ -1314,39 +1356,22 @@ def scan_fn_ltx23(carry, block_and_mask): self.names_which_can_be_offloaded, prevent_cse=not self.scan_layers, ) - carry = (hidden_states, audio_hidden_states, nnx.Rngs(0)) - (hidden_states, audio_hidden_states, _), _ = nnx.scan( + carry = (base_context, nnx.Rngs(0)) + (final_context, _), _ = nnx.scan( rematted_scan_fn, length=self.num_layers, in_axes=(nnx.Carry, 0), out_axes=(nnx.Carry, 0), - transform_metadata={nnx.PARTITION_NAME: "layers"}, )(carry, (self.transformer_blocks, perturbation_mask_per_layer)) + hidden_states = final_context.hidden_states + audio_hidden_states = final_context.audio_hidden_states else: + current_context = base_context for i, block in enumerate(self.transformer_blocks): mask = perturbation_mask_per_layer[i] if perturbation_mask_per_layer is not None else None - hidden_states, audio_hidden_states = block( - hidden_states=hidden_states, - audio_hidden_states=audio_hidden_states, - encoder_hidden_states=encoder_hidden_states, - audio_encoder_hidden_states=audio_encoder_hidden_states, - temb=temb, - temb_audio=temb_audio, - temb_ca_scale_shift=video_cross_attn_scale_shift, - temb_ca_audio_scale_shift=audio_cross_attn_scale_shift, - temb_ca_gate=video_cross_attn_a2v_gate, - temb_ca_audio_gate=audio_cross_attn_v2a_gate, - temb_prompt=temb_prompt, - temb_prompt_audio=temb_prompt_audio, - video_rotary_emb=video_rotary_emb, - audio_rotary_emb=audio_rotary_emb, - ca_video_rotary_emb=video_cross_attn_rotary_emb, - ca_audio_rotary_emb=audio_cross_attn_rotary_emb, - encoder_attention_mask=encoder_attention_mask, - audio_encoder_attention_mask=audio_encoder_attention_mask, - perturbation_mask=mask, - modality_mask=modality_mask, - ) + current_context, _ = apply_block(block, current_context, mask, None) + hidden_states = current_context.hidden_states + audio_hidden_states = current_context.audio_hidden_states # 6. Output layers with jax.named_scope("Output Projection & Norm"): diff --git a/src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py b/src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py index 2309c7b44..1603a793d 100644 --- a/src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py +++ b/src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py @@ -81,6 +81,8 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): """ std_text = jnp.std(noise_pred_text, axis=list(range(1, noise_pred_text.ndim)), keepdims=True) std_cfg = jnp.std(noise_cfg, axis=list(range(1, noise_cfg.ndim)), keepdims=True) + # Prevent division by zero + std_cfg = jnp.maximum(std_cfg, 1e-15) # rescale the results from guidance (fixes overexposure) noise_pred_rescaled = noise_cfg * (std_text / std_cfg) # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images @@ -855,6 +857,8 @@ def _get_gemma_prompt_embeds( prompt = [p.strip() for p in prompt] + target_dtype = dtype if dtype is not None else jnp.bfloat16 + if self.text_encoder is not None: run_text_encoder_on_tpu = getattr(self.config, "run_text_encoder_on_tpu", False) if hasattr(self, "config") else False if run_text_encoder_on_tpu: @@ -872,30 +876,31 @@ def _get_gemma_prompt_embeds( # Distribute the batch dimension across available TPUs to prevent Softmax OOM # (reduces 512MB allocation down to 64MB per TPU for batch size 16) - devices = np.array(jax.devices()) - num_shards = 1 - for i in range(len(devices), 0, -1): - if text_input_ids.shape[0] % i == 0: - num_shards = i - break - - if num_shards > 1: - mesh = Mesh(devices[:num_shards], axis_names=("batch",)) - sharding = NamedSharding(mesh, P("batch")) + if hasattr(self, "mesh") and self.mesh is not None: + data_axis = self.mesh.axis_names[0] + sharding = NamedSharding(self.mesh, P(data_axis)) text_input_ids = jax.device_put(text_input_ids, sharding) prompt_attention_mask = jax.device_put(prompt_attention_mask, sharding) + else: + devices = np.array(jax.devices()) + num_shards = 1 + for i in range(len(devices), 0, -1): + if text_input_ids.shape[0] % i == 0: + num_shards = i + break + + if num_shards > 1: + mesh = Mesh(devices[:num_shards], axis_names=("batch",)) + sharding = NamedSharding(mesh, P("batch")) + text_input_ids = jax.device_put(text_input_ids, sharding) + prompt_attention_mask = jax.device_put(prompt_attention_mask, sharding) # Torchax wrapper returns tuple of hidden states natively text_encoder_hidden_states = self.text_encoder( input_ids=text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True ) - prompt_embeds_list = [] - # Iterate instead of stacking eagerly to avoid 5.7+ GB HBM allocations outside JIT - for state in text_encoder_hidden_states: - prompt_embeds_list.append(state.astype(jnp.bfloat16)) - - prompt_embeds = prompt_embeds_list + prompt_embeds = jax.tree.map(lambda x: x.astype(target_dtype), list(text_encoder_hidden_states)) del text_encoder_hidden_states # Free memory prompt_attention_mask = prompt_attention_mask.astype(jnp.bool_) @@ -923,25 +928,16 @@ def _get_gemma_prompt_embeds( text_encoder_hidden_states = text_encoder_outputs.hidden_states del text_encoder_outputs # Free memory - prompt_embeds_list = [] - # Iterate instead of stacking eagerly to avoid 5.7+ GB HBM allocations outside JIT - for state in text_encoder_hidden_states: - state_np = state.cpu().to(torch.float32).numpy() - prompt_embeds_list.append(jnp.array(state_np, dtype=jnp.bfloat16)) - - prompt_embeds = prompt_embeds_list + prompt_embeds = jax.tree.map( + lambda state: jnp.array(state.cpu().to(torch.float32).numpy(), dtype=target_dtype), + list(text_encoder_hidden_states), + ) del text_encoder_hidden_states # Free PyTorch tensor memory prompt_attention_mask = jnp.array(prompt_attention_mask.cpu().to(torch.float32).numpy(), dtype=jnp.bool_) else: raise ValueError("`text_encoder` is required to encode prompts.") - if dtype is not None: - if isinstance(prompt_embeds, list): - prompt_embeds = [state.astype(dtype) for state in prompt_embeds] - else: - prompt_embeds = prompt_embeds.astype(dtype) - if isinstance(prompt_embeds, list): _, seq_len, _ = prompt_embeds[0].shape prompt_embeds = [ @@ -1175,6 +1171,9 @@ def _create_noised_state(latents: jax.Array, noise_scale: float, generator: Opti else: # Fallback or expect noise to be handled otherwise? # pipeline prepare_latents typically generates noise. + max_logging.log( + "WARNING: No PRNG generator provided. Falling back to deterministic zero-seed noise (jax.random.key(0))." + ) noise = jax.random.normal(jax.random.key(0), latents.shape, dtype=latents.dtype) # Default fallback noised_latents = noise_scale * noise + (1 - noise_scale) * latents @@ -1555,12 +1554,22 @@ def __call__( audio_embeds_sharded = audio_embeds if not self.transformer.scan_layers: - activation_axes = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_embed")) - activation_axes_audio = nn.logical_to_mesh_axes(("activation_batch", None, "activation_embed")) - spec = NamedSharding(self.mesh, P(*activation_axes)) - spec_audio = NamedSharding(self.mesh, P(*activation_axes_audio)) - video_embeds_sharded = jax.device_put(video_embeds, spec) - audio_embeds_sharded = jax.device_put(audio_embeds, spec_audio) + with nn_partitioning.axis_rules(self.config.logical_axis_rules): + activation_axes = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_embed")) + activation_axes_audio = nn.logical_to_mesh_axes(("activation_batch", None, "activation_embed")) + + @jax.jit + def enforce_layout_act(x): + return jax.lax.with_sharding_constraint(x, activation_axes) + + @jax.jit + def enforce_layout_act_audio(x): + return jax.lax.with_sharding_constraint(x, activation_axes_audio) + + latents_jax = enforce_layout_act(latents_jax) + audio_latents_jax = enforce_layout_act_audio(audio_latents_jax) + video_embeds_sharded = enforce_layout_act(video_embeds_sharded) + audio_embeds_sharded = enforce_layout_act_audio(audio_embeds_sharded) timesteps_jax = jnp.array(timesteps, dtype=jnp.float32) t0_denoise = time.perf_counter() @@ -1595,6 +1604,8 @@ def __call__( self.scheduler.step, tuple(tuple(rule) if isinstance(rule, list) else rule for rule in self.config.logical_axis_rules), use_cross_timestep=use_cross_timestep, + do_cfg=do_cfg, + do_stg=do_stg, ) else: # Old Python loop path @@ -2021,14 +2032,6 @@ def transformer_forward_pass( @partial( jax.jit, static_argnames=( - "guidance_scale", - "stg_scale", - "modality_scale", - "guidance_rescale", - "audio_guidance_scale", - "audio_stg_scale", - "audio_modality_scale", - "audio_guidance_rescale", "latent_num_frames", "latent_height", "latent_width", @@ -2039,6 +2042,8 @@ def transformer_forward_pass( "scheduler_step", "logical_axis_rules", "use_cross_timestep", + "do_cfg", + "do_stg", ), ) def run_diffusion_loop( @@ -2071,15 +2076,14 @@ def run_diffusion_loop( logical_axis_rules, perturbation_mask=None, use_cross_timestep=False, + do_cfg=False, + do_stg=False, ): """Runs the diffusion loop.""" # pylint: disable=too-many-positional-arguments latents_jax = latents_jax.astype(jnp.float32) audio_latents_jax = audio_latents_jax.astype(jnp.float32) - do_cfg = guidance_scale > 1.0 - do_stg = stg_scale > 0.0 - # Helper functions matching Diffusers Delta formulation def convert_to_x0(lat, vel, sigma_t): return lat - vel * sigma_t @@ -2097,8 +2101,16 @@ def scan_body(carry, inputs): if not scan_layers: activation_axis_names = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_embed")) + activation_axis_names_audio = nn.logical_to_mesh_axes(("activation_batch", None, "activation_embed")) latents_sharded = jax.lax.with_sharding_constraint(latents, activation_axis_names) - audio_latents_sharded = jax.lax.with_sharding_constraint(audio_latents, activation_axis_names) + audio_latents_sharded = jax.lax.with_sharding_constraint(audio_latents, activation_axis_names_audio) + video_embeds_sharded_constrained = jax.lax.with_sharding_constraint(video_embeds_sharded, activation_axis_names) + audio_embeds_sharded_constrained = jax.lax.with_sharding_constraint( + audio_embeds_sharded, activation_axis_names_audio + ) + else: + video_embeds_sharded_constrained = video_embeds_sharded + audio_embeds_sharded_constrained = audio_embeds_sharded # Forward Pass noise_pred, noise_pred_audio = transformer_forward_pass( @@ -2107,8 +2119,8 @@ def scan_body(carry, inputs): latents_sharded, audio_latents_sharded, t, - video_embeds_sharded, - audio_embeds_sharded, + video_embeds_sharded_constrained, + audio_embeds_sharded_constrained, new_attention_mask, new_attention_mask, latent_num_frames=latent_num_frames, @@ -2151,8 +2163,11 @@ def scan_body(carry, inputs): x0_combined = x0_text + cfg_delta + stg_delta + video_modality_delta - if guidance_rescale > 0: - x0_combined = rescale_noise_cfg(x0_combined, x0_text, guidance_rescale=guidance_rescale) + x0_combined = jax.lax.cond( + guidance_rescale > 0, + lambda: rescale_noise_cfg(x0_combined, x0_text, guidance_rescale=guidance_rescale), + lambda: x0_combined, + ) noise_pred = convert_to_vel(latents_step, x0_combined, sigma_t) @@ -2178,8 +2193,11 @@ def scan_body(carry, inputs): x0_audio_combined = x0_audio_text + cfg_audio_delta + stg_audio_delta + audio_modality_delta - if audio_guidance_rescale > 0: - x0_audio_combined = rescale_noise_cfg(x0_audio_combined, x0_audio_text, guidance_rescale=audio_guidance_rescale) + x0_audio_combined = jax.lax.cond( + audio_guidance_rescale > 0, + lambda: rescale_noise_cfg(x0_audio_combined, x0_audio_text, guidance_rescale=audio_guidance_rescale), + lambda: x0_audio_combined, + ) noise_pred_audio = convert_to_vel(audio_latents_step, x0_audio_combined, sigma_t) @@ -2214,10 +2232,6 @@ def scan_body(carry, inputs): initial_carry = (latents_jax, audio_latents_jax, scheduler_state) scan_inputs = (timesteps_jax, sigmas) - final_carry, _ = nnx.scan( - scan_body, - in_axes=(nnx.Carry, 0), - out_axes=(nnx.Carry, 0), - )(initial_carry, scan_inputs) + final_carry, _ = jax.lax.scan(scan_body, initial_carry, scan_inputs) return final_carry[0], final_carry[1] diff --git a/src/maxdiffusion/tests/ltx2/test_transformer_ltx2.py b/src/maxdiffusion/tests/ltx2/test_transformer_ltx2.py index f8c5b8d62..281b82784 100644 --- a/src/maxdiffusion/tests/ltx2/test_transformer_ltx2.py +++ b/src/maxdiffusion/tests/ltx2/test_transformer_ltx2.py @@ -25,6 +25,7 @@ from maxdiffusion import pyconfig from maxdiffusion.max_utils import create_device_mesh from maxdiffusion.models.ltx2.transformer_ltx2 import ( + LTX2BlockContext, LTX2VideoTransformerBlock, LTX2VideoTransformer3DModel, LTX2AdaLayerNormSingle, @@ -203,7 +204,7 @@ def test_ltx2_transformer_block(self): temb_ca_gate = jnp.zeros((batch_size, 1 * dim)) temb_ca_audio_gate = jnp.zeros((batch_size, 1 * audio_dim)) - output_hidden, output_audio = block( + ctx = LTX2BlockContext( hidden_states=hidden_states, audio_hidden_states=audio_hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -216,6 +217,8 @@ def test_ltx2_transformer_block(self): temb_ca_audio_gate=temb_ca_audio_gate, ) + output_hidden, output_audio = block(ctx) + self.assertEqual(output_hidden.shape, hidden_states.shape) self.assertEqual(output_audio.shape, audio_hidden_states.shape)