diff --git a/src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py b/src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py index 259fff593..2309c7b44 100644 --- a/src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py +++ b/src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py @@ -1818,10 +1818,9 @@ def convert_to_vel(lat, x0, sig): enable_dynamic_vae_sharding = ( getattr(self.config, "enable_dynamic_vae_sharding", True) if hasattr(self, "config") else True ) - if enable_dynamic_vae_sharding and batch_size > 2: max_logging.log( - f"[Tuning] Skipping VAE replication and disabling slicing to prevent HBM OOM for batch_size {batch_size} > 2" + f"[Tuning] Disabling VAE slicing and applying dynamic batch sharding to prevent HBM OOM for batch_size {batch_size} > 2" ) try: # Disable sequential slicing to avoid JAX concatenating 17GB arrays on the TPU @@ -1847,6 +1846,13 @@ def convert_to_vel(lat, x0, sig): mesh = latents.sharding.mesh replicated_sharding = NamedSharding(mesh, P()) latents = jax.lax.with_sharding_constraint(latents, replicated_sharding) + except Exception as e: # pylint: disable=broad-exception-caught + max_logging.log(f"[Tuning] Failed to apply replicate VAE latents sharding: {e}") + + if replicate_vae: + try: + mesh = latents.sharding.mesh + replicated_sharding = NamedSharding(mesh, P()) # Replicate VAE weights graphdef, state = nnx.split(self.vae) state = jax.tree_util.tree_map( @@ -1854,7 +1860,7 @@ def convert_to_vel(lat, x0, sig): ) self.vae = nnx.merge(graphdef, state) except Exception as e: # pylint: disable=broad-exception-caught - max_logging.log(f"[Tuning] Failed to apply sharding constraint: {e}") + max_logging.log(f"[Tuning] Failed to replicate VAE weights: {e}") latent_processing_time += time.perf_counter() - t0_latent_processing timings["Latent Processing"] = latent_processing_time