Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -1818,10 +1818,9 @@ def convert_to_vel(lat, x0, sig):
enable_dynamic_vae_sharding = (
getattr(self.config, "enable_dynamic_vae_sharding", True) if hasattr(self, "config") else True
)

if enable_dynamic_vae_sharding and batch_size > 2:
max_logging.log(
- f"[Tuning] Skipping VAE replication and disabling slicing to prevent HBM OOM for batch_size {batch_size} > 2"
+ f"[Tuning] Disabling VAE slicing and applying dynamic batch sharding to prevent HBM OOM for batch_size {batch_size} > 2"
)
try:
# Disable sequential slicing to avoid JAX concatenating 17GB arrays on the TPU
Expand All @@ -1847,14 +1846,21 @@ def convert_to_vel(lat, x0, sig):
mesh = latents.sharding.mesh
replicated_sharding = NamedSharding(mesh, P())
latents = jax.lax.with_sharding_constraint(latents, replicated_sharding)
except Exception as e: # pylint: disable=broad-exception-caught
max_logging.log(f"[Tuning] Failed to apply replicate VAE latents sharding: {e}")

if replicate_vae:
try:
mesh = latents.sharding.mesh
replicated_sharding = NamedSharding(mesh, P())
# Replicate VAE weights
graphdef, state = nnx.split(self.vae)
state = jax.tree_util.tree_map(
lambda x: jax.lax.with_sharding_constraint(x, replicated_sharding) if isinstance(x, jax.Array) else x, state
)
self.vae = nnx.merge(graphdef, state)
except Exception as e: # pylint: disable=broad-exception-caught
- max_logging.log(f"[Tuning] Failed to apply sharding constraint: {e}")
+ max_logging.log(f"[Tuning] Failed to replicate VAE weights: {e}")

latent_processing_time += time.perf_counter() - t0_latent_processing
timings["Latent Processing"] = latent_processing_time
Expand Down
Loading