
Fix VRAM leak in tiler fallback in video VAEs#13073

Merged
comfyanonymous merged 3 commits into Comfy-Org:master from rattus128:prs/tiler-fallback-leak
Mar 20, 2026

Conversation

@rattus128
Contributor

#13023 (specifically the secondary report)

Primary commit message

wan: vae: Don't recurse in local fns (move run_up)
Moved Decoder3d's recursive run_up out of forward into a class
method to avoid nested closure self-reference cycles. This avoids
cyclic garbage that delays garbage collection of tensors, which in
turn delays VRAM release before the tiled fallback.

  • The same change is applied to LTX (for consistency).

Also soft-empty the cache so nvtop and friends look better.
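
To illustrate the pattern (a minimal, simplified sketch, not the actual Decoder3d code): a recursive local function defined inside forward captures itself through its own closure cell, creating a reference cycle that keeps captured tensors alive until the cyclic garbage collector runs, which is too late for the tiled fallback.

import torch

class DecoderBefore(torch.nn.Module):
    def forward(self, x):
        # run_up refers to itself by name, so its closure cell holds the
        # function and the function holds the cell: a reference cycle.
        def run_up(t, depth):
            if depth == 0:
                return t
            return run_up(t * 2, depth - 1)
        return run_up(x, 3)  # the cycle (and anything it captures) waits for gc.collect()

class DecoderAfter(torch.nn.Module):
    # Plain method: no closure, no cycle; intermediates are freed by
    # reference counting as soon as the last reference goes away.
    def run_up(self, t, depth):
        if depth == 0:
            return t
        return self.run_up(t * 2, depth - 1)

    def forward(self, x):
        return self.run_up(x, 3)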

Example test conditions:

Linux, 5090
WAN VAE encode -> decode 2048x2048x21f
--disable-cuda-malloc

This hack was used to get a single-tile memory trace:

diff --git a/comfy/sd.py b/comfy/sd.py
index e207bb0f..67206276 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -942,6 +942,8 @@ class VAE:
         self.throw_exception_if_invalid()
         pixel_samples = None
         do_tile = False
+        snapshot_file = "vae_decode_snapshot.pickle"
+        torch.cuda.memory._record_memory_history(max_entries=100000)
         if self.latent_dim == 2 and samples_in.ndim == 5:
             samples_in = samples_in[:, :, 0]
         try:
diff --git a/comfy/utils.py b/comfy/utils.py
index 78c491b9..521849af 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -1125,12 +1125,16 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
 
     output = torch.empty([samples.shape[0], out_channels] + mult_list_upscale(samples.shape[2:]), device=output_device)
 
+    snapshot_file = "vae_decode_snapshot.pickle"
     for b in range(samples.shape[0]):
         s = samples[b:b+1]
 
         # handle entire input fitting in a single tile
         if all(s.shape[d+2] <= tile[d] for d in range(dims)):
             output[b:b+1] = function(s).to(output_device)
+            torch.cuda.memory._dump_snapshot(snapshot_file)
+            torch.cuda.memory._record_memory_history(enabled=None)
+            raise RuntimeError("early exiting the tiler for test purposes")
             if pbar is not None:
                 pbar.update(1)
             continue
@@ -1151,6 +1155,9 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
                 upscaled.append(round(get_pos(d, pos)))
 
             ps = function(s_in).to(output_device)
+            torch.cuda.memory._dump_snapshot(snapshot_file)
+            torch.cuda.memory._record_memory_history(enabled=None)
+            raise RuntimeError("early exiting the tiler for test purposes")
             mask = torch.ones([1, 1] + list(ps.shape[2:]), device=output_device)
 
             for d in range(2, dims + 2):
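
The same tracing pattern can also be run standalone (a sketch, assuming a recent PyTorch build with the private _record_memory_history/_dump_snapshot API); the resulting pickle can be inspected by dragging it into https://pytorch.org/memory_viz.

import torch

# Record allocator history, run the workload, then dump a snapshot to disk.
torch.cuda.memory._record_memory_history(max_entries=100000)
x = torch.randn(4096, 4096, device="cuda")
y = x @ x                                   # stand-in for the VAE decode under test
torch.cuda.synchronize()
torch.cuda.memory._dump_snapshot("vae_decode_snapshot.pickle")
torch.cuda.memory._record_memory_history(enabled=None)   # stop recording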

Before:

[memory snapshot screenshot]

After:

[memory snapshot screenshot]

I can't reproduce the issue on LTX, but applied the same change anyway.

Regression Tests:
LTX2.3 ✅
WAN2.2 I2V ✅

This doesn't cost a lot and produces the expected VRAM reduction in
resource monitors when you fall back to the tiler.
@coderabbitai

coderabbitai bot commented Mar 20, 2026

📝 Walkthrough

The PR refactors two VAE decoder implementations to extract nested run_up helper functions into dedicated methods with explicit parameters. In causal_video_autoencoder.py, the Decoder.run_up method consolidates upsampling recursion and chunking logic. In vae.py, the Decoder3d.run_up method encapsulates recursive frame processing. Additionally, memory management calls are added to the tiled encoding and decoding paths in sd.py to invoke soft cache clearing before processing tiles.
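
As a rough sketch of that fallback path (assumed names: decode_with_fallback is hypothetical, soft_empty_cache refers to ComfyUI's comfy.model_management helper, and the actual control flow in sd.py may differ):

import torch
import comfy.model_management as mm

def decode_with_fallback(vae, latent):
    # Try the whole-frame decode first; on OOM, release cached allocator
    # blocks so the tiled retry starts from a clean VRAM baseline.
    try:
        return vae.first_stage_model.decode(latent)
    except torch.cuda.OutOfMemoryError:
        mm.soft_empty_cache()
        return vae.decode_tiled_(latent)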

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 12.50%, which is below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)
  • Title check ✅ Passed: The title 'Fix VRAM leak in tiler fallback in video VAEs' directly and clearly describes the main objective of the PR: addressing a VRAM leak issue in the tiled fallback path for video VAEs.
  • Description check ✅ Passed: The description is comprehensive and directly related to the changeset, explaining the root cause (nested closure self-reference cycles creating cyclic garbage), the solution (moving run_up to a class method), and including test results and debug details.


@coderabbitai coderabbitai bot left a comment

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
comfy/ldm/lightricks/vae/causal_video_autoencoder.py (1)

604-641: ⚠️ Potential issue | 🔴 Critical

Initialize scaled_timestep for non-conditioned decode paths (runtime blocker).

scaled_timestep is only set inside if self.timestep_conditioning: but is always passed to run_up() at line 641. When self.timestep_conditioning is False, this raises UnboundLocalError.

Proposed fix
         timestep_shift_scale = None
+        scaled_timestep = None
         if self.timestep_conditioning:
             assert (
                 timestep is not None
             ), "should pass timestep with timestep_conditioning=True"
             scaled_timestep = timestep * self.timestep_scale_multiplier.to(dtype=sample.dtype, device=sample.device)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@comfy/ldm/lightricks/vae/causal_video_autoencoder.py` around lines 604 - 641,
The variable scaled_timestep is only assigned inside the if
self.timestep_conditioning branch but is always passed to run_up(...), causing
an UnboundLocalError when timestep_conditioning is False; initialize
scaled_timestep (e.g., set scaled_timestep = None) before the conditional or
ensure you pass a defined fallback to run_up, updating the block around
scaled_timestep and the call to self.run_up(...) so run_up receives a valid
value regardless of self.timestep_conditioning.
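
A minimal, self-contained reproduction of this failure mode (simplified, hypothetical names, not the actual decoder code):

def decode(sample, timestep_conditioning=False, timestep=None):
    if timestep_conditioning:
        scaled_timestep = timestep * 1000.0
    # scaled_timestep is referenced unconditionally, so when
    # timestep_conditioning is False this raises UnboundLocalError.
    return sample, scaled_timestep

decode(1.0)   # UnboundLocalError: local variable 'scaled_timestep' referenced before assignment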

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 12f3042d-2d1b-44ae-9dc6-ec82d0567ac7

📥 Commits

Reviewing files that changed from the base of the PR and between 8458ae2 and 6b138a8.

📒 Files selected for processing (3)
  • comfy/ldm/lightricks/vae/causal_video_autoencoder.py
  • comfy/ldm/wan/vae.py
  • comfy/sd.py

@comfyanonymous comfyanonymous merged commit 82b868a into Comfy-Org:master Mar 20, 2026
14 checks passed