Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/maxdiffusion/generate_ltx2.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,8 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):

# Export videos
for i in range(len(videos)):
video_path = f"{filename_prefix}ltx2_output_{getattr(config, 'seed', 0)}_{i}.mp4"
model_name_prefix = getattr(config, "model_name", "ltx2").replace(".", "_")
video_path = f"{filename_prefix}{model_name_prefix}_output_{getattr(config, 'seed', 0)}_{i}.mp4"
audio_i = audios[i] if audios is not None else None

audio_format = getattr(config, "audio_format", "s16")
Expand Down
4 changes: 2 additions & 2 deletions src/maxdiffusion/models/attention_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -1383,7 +1383,7 @@ def __init__(
self,
mesh: Mesh,
attention_kernel: str,
scale: int,
scale: float,
heads: int,
dim_head: int,
use_memory_efficient_attention: bool = False,
Expand Down Expand Up @@ -1475,7 +1475,7 @@ def apply_attention(self, query: Array, key: Array, value: Array, attention_mask
class AttentionOp(nn.Module):
mesh: Mesh
attention_kernel: str
scale: int
scale: float
heads: int
dim_head: int
use_memory_efficient_attention: bool = False
Expand Down
17 changes: 3 additions & 14 deletions src/maxdiffusion/models/ltx2/attention_ltx2.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,8 @@ def apply_split_rotary_emb(x: Array, freqs: Tuple[Array, Array]) -> Array:
first_x = split_x[..., 0, :]
second_x = split_x[..., 1, :]

cos_u = jnp.expand_dims(cos, axis=-2)
sin_u = jnp.expand_dims(sin, axis=-2)

out = split_x * cos_u

out_first = out[..., 0, :] - second_x * sin_u.squeeze(-2)
out_second = out[..., 1, :] + first_x * sin_u.squeeze(-2)
out_first = first_x * cos - second_x * sin
out_second = second_x * cos + first_x * sin

out = jnp.stack([out_first, out_second], axis=-2)
out = out.reshape(*out.shape[:-2], last_dim)
Expand Down Expand Up @@ -176,12 +171,6 @@ def prepare_video_coords(
patch_ends = grid + patch_size_delta

# Combine start and end coordinates
latent_coords = jnp.stack([grid, patch_ends], axis=-1) # [3, N_F, N_H, N_W, 2]
latent_coords = latent_coords.transpose(1, 2, 3, 0, 4) # [N_F, N_H, N_W, 3, 2]
latent_coords = latent_coords.reshape(-1, 3, 2) # [num_patches, 3, 2]
latent_coords = jnp.expand_dims(latent_coords, 0) # [1, num_patches, 3, 2]
latent_coords = jnp.tile(latent_coords, (batch_size, 1, 1, 1)) # [B, num_patches, 3, 2]

latent_coords = jnp.stack([grid, patch_ends], axis=-1) # [3, N_F, N_H, N_W, 2]
latent_coords = latent_coords.reshape(3, -1, 2) # [3, num_patches, 2]
latent_coords = jnp.expand_dims(latent_coords, 0) # [1, 3, num_patches, 2]
Expand Down Expand Up @@ -485,7 +474,7 @@ def __call__(
# 3. Apply RoPE
with jax.named_scope("Apply RoPE"):
if rotary_emb is not None:
if hasattr(self, "rope_type") and self.rope_type == "split":
if self.rope_type == "split":
# Split RoPE: passing full freqs [B, H, S, D//2]
# apply_split_rotary_emb handles reshaping query/key

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@
from torchax import interop, default_env

# --- Monkeypatch transformers masking_utils to avoid torchax integer tracing bug ---
import transformers.masking_utils

_orig_sliding_window_overlay = transformers.masking_utils.sliding_window_overlay
from unittest import mock


def _patched_sliding_window_overlay(sliding_window: int):
Expand Down Expand Up @@ -57,8 +55,7 @@ def __call__(
self, input_ids: jax.Array, attention_mask: jax.Array, output_hidden_states: bool = True
) -> Tuple[jax.Array, ...]:
# Dynamically patch transformers.masking_utils only during the duration of this call
transformers.masking_utils.sliding_window_overlay = _patched_sliding_window_overlay
try:
with mock.patch("transformers.masking_utils.sliding_window_overlay", _patched_sliding_window_overlay):
with default_env():
input_ids = interop.torch_view(input_ids)
attention_mask = interop.torch_view(attention_mask)
Expand All @@ -72,9 +69,6 @@ def __call__(
output_hidden_states=output_hidden_states,
)
return interop.jax_view(output)
finally:
# Restore original behavior to prevent side effects on other potential models in same env
transformers.masking_utils.sliding_window_overlay = _orig_sliding_window_overlay

@staticmethod
def _forward_inner(model, input_ids, attention_mask, output_hidden_states=True):
Expand Down
Loading
Loading