From 96e9d82f137395598bed9801858f0b3cbc57b77d Mon Sep 17 00:00:00 2001 From: mbohlool Date: Tue, 16 Jun 2026 16:51:07 +0000 Subject: [PATCH] fix(ltx2): Fix VAE timing regression for large batch sizes Fixes a massive execution time regression (68s) in VAE decode by explicitly replicating VAE weights even when latents are batch-sharded. This forces XLA into an optimal data-parallel path, restoring the fast ~2s execution time while retaining batch-sharding OOM protections. --- src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py b/src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py index 259fff593..2309c7b44 100644 --- a/src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py +++ b/src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py @@ -1818,10 +1818,9 @@ def convert_to_vel(lat, x0, sig): enable_dynamic_vae_sharding = ( getattr(self.config, "enable_dynamic_vae_sharding", True) if hasattr(self, "config") else True ) - if enable_dynamic_vae_sharding and batch_size > 2: max_logging.log( - f"[Tuning] Skipping VAE replication and disabling slicing to prevent HBM OOM for batch_size {batch_size} > 2" + f"[Tuning] Disabling VAE slicing and applying dynamic batch sharding to prevent HBM OOM for batch_size {batch_size} > 2" ) try: # Disable sequential slicing to avoid JAX concatenating 17GB arrays on the TPU @@ -1847,6 +1846,13 @@ def convert_to_vel(lat, x0, sig): mesh = latents.sharding.mesh replicated_sharding = NamedSharding(mesh, P()) latents = jax.lax.with_sharding_constraint(latents, replicated_sharding) + except Exception as e: # pylint: disable=broad-exception-caught + max_logging.log(f"[Tuning] Failed to apply replicate VAE latents sharding: {e}") + + if replicate_vae: + try: + mesh = latents.sharding.mesh + replicated_sharding = NamedSharding(mesh, P()) # Replicate VAE weights graphdef, state = nnx.split(self.vae) state = jax.tree_util.tree_map( @@ -1854,7 +1860,7 @@ def convert_to_vel(lat, x0, sig): ) self.vae = nnx.merge(graphdef, state) except Exception as e: # pylint: disable=broad-exception-caught - max_logging.log(f"[Tuning] Failed to apply sharding constraint: {e}") + max_logging.log(f"[Tuning] Failed to replicate VAE weights: {e}") latent_processing_time += time.perf_counter() - t0_latent_processing timings["Latent Processing"] = latent_processing_time