Skip to content

fix(ltx2): Fix VAE timing regression for large batch sizes#421

Merged
copybara-service[bot] merged 1 commit into
mainfrom
fix-vae-timing-regression
Jun 17, 2026
Merged

fix(ltx2): Fix VAE timing regression for large batch sizes#421
copybara-service[bot] merged 1 commit into
mainfrom
fix-vae-timing-regression

Conversation

@mbohlool

Copy link
Copy Markdown
Collaborator

Fix VAE timing regression for large batch sizes

Root Cause:
In commit 7b28885, an optimization was added to prevent OOM errors for large batch sizes (batch_size > 2) by batch-sharding the latents and disabling sequential slicing. However, this logic used an elif replicate_vae: block, which caused the explicit replication of VAE weights to be entirely skipped for large batch sizes.

Without explicit weight replication, the XLA SPMD partitioner attempts to match the sharding of the input latents (which are batch-sharded) with the VAE decode computation. Because vae.decode involves fully-replicated noise injection and massive 3D convolutions, XLA heuristically decides to insert enormous amounts of cross-device communication (AllGather/AllReduce) to shard the weights or activations, ballooning the execution time from ~2.8s to ~68.5s for non-upsampled latents.

(Note: For upsampled latents, the memory layout generated by the JIT-compiled upsampler bypasses this XLA heuristic trap, allowing it to execute quickly in ~1.5s, which masked the issue).

Fix:
This PR decouples the batch-sharding of latents from the replication of VAE weights. It explicitly applies a full replication constraint NamedSharding(mesh, P()) to the VAE weights in all cases where replicate_vae is True, even if latents are batch-sharded. This forces XLA into the optimal data-parallel compilation path, restoring the fast ~1.5s - ~2.8s execution time for all scenarios without risking the concatenation-related OOMs.

@mbohlool mbohlool requested a review from entrpn as a code owner June 16, 2026 17:09
@github-actions

Copy link
Copy Markdown

@mbohlool mbohlool requested a review from prishajain1 June 16, 2026 17:11
Fixes a massive execution time regression (68s) in VAE decode by explicitly replicating VAE weights even when latents are batch-sharded. This forces XLA into an optimal data-parallel path, restoring the fast ~2s execution time while retaining batch-sharding OOM protections.
@mbohlool mbohlool force-pushed the fix-vae-timing-regression branch from 27eda58 to 96e9d82 Compare June 16, 2026 17:13
@copybara-service copybara-service Bot merged commit 9616d1c into main Jun 17, 2026
16 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants