Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions comfy/comfy_types/node_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ class IO(StrEnum):
BBOX = "BBOX"
SEGS = "SEGS"
VIDEO = "VIDEO"
IMAGE_STREAM = "IMAGE_STREAM"

ANY = "*"
"""Always matches any type, but at a price.
Expand Down
165 changes: 130 additions & 35 deletions comfy/ldm/lightricks/vae/causal_video_autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,17 @@

ops = comfy.ops.disable_weight_init

class RunUpState:
    """Mutable bag of per-decode state threaded through the decoder's
    streaming run_up() recursion.
    """

    def __init__(self, timestep_shift_scale, scaled_timestep, checkpoint_fn,
                 max_chunk_size, output_shape, output_dtype, output_frames=None):
        # Conditioning values applied in the final output stage.
        self.timestep_shift_scale = timestep_shift_scale
        self.scaled_timestep = scaled_timestep
        # Wrapper applied to each up block before calling it.
        self.checkpoint_fn = checkpoint_fn
        # Byte budget for an intermediate tensor before it gets chunked.
        self.max_chunk_size = max_chunk_size
        # Shape and dtype of the full decoded output.
        self.output_shape = output_shape
        self.output_dtype = output_dtype
        # Decoded frames produced but not yet copied into an output buffer.
        self.output_frames = output_frames
        # Parked (idx, chunks, ended) work items awaiting the next resume.
        self.pending_samples = []

def in_meta_context():
    """Return True when new tensors are currently allocated on the meta
    device (i.e. we are running inside a torch meta-device context)."""
    probe = torch.empty(0)
    return probe.device == torch.device("meta")

Expand All @@ -26,6 +37,14 @@ def mark_conv3d_ended(module):
current = m.temporal_cache_state.get(tid, (None, False))
m.temporal_cache_state[tid] = (current[0], True)

def clear_temporal_cache_state(module):
    """Drop the calling thread's temporal-cache entry from every submodule
    of *module* that keeps one."""
    # ComfyUI doesn't thread this kind of stuff today, but keying the cache
    # on the thread id keeps it thread safe just in case.
    tid = threading.get_ident()
    for _, sub in module.named_modules():
        if hasattr(sub, "temporal_cache_state"):
            sub.temporal_cache_state.pop(tid, None)

def split2(tensor, split_point, dim=2):
    """Split *tensor* along *dim* into the first *split_point* entries and
    the remainder, returning a 2-tuple of views."""
    tail = tensor.shape[dim] - split_point
    return torch.split(tensor, [split_point, tail], dim=dim)

Expand Down Expand Up @@ -315,13 +334,7 @@ def forward(self, *args, **kwargs):
try:
return self.forward_orig(*args, **kwargs)
finally:
tid = threading.get_ident()
for _, module in self.named_modules():
# ComfyUI doesn't thread this kind of stuff today, but just in case
# we key on the thread to make it thread safe.
tid = threading.get_ident()
if hasattr(module, "temporal_cache_state"):
module.temporal_cache_state.pop(tid, None)
clear_temporal_cache_state(self)


MIN_VRAM_FOR_CHUNK_SCALING = 6 * 1024 ** 3
Expand Down Expand Up @@ -530,58 +543,70 @@ def __init__(
).unsqueeze(1).expand(2, output_channel),
persistent=False,
)
self.temporal_cache_state = {}


def decode_output_shape(self, input_shape):
    """Map a latent shape (B, C, T, H, W) to the decoded pixel-space shape.

    self._output_scale is (channels, (t_scale, h_scale, w_scale), t_offset);
    the temporal axis is scaled then shrunk by the offset.
    """
    channels, (t_scale, h_scale, w_scale), t_offset = self._output_scale
    batch = input_shape[0]
    t, h, w = input_shape[2], input_shape[3], input_shape[4]
    return (batch, channels, t * t_scale - t_offset, h * h_scale, w * w_scale)

def run_up(self, idx, sample_ref, ended, run_up_state, output_buffer, output_offset):
    """Recursively push a sample through up_blocks[idx:], streaming output.

    This span was diff residue interleaving the pre- and post-change lines;
    it is reconstructed here as the post-change implementation.

    Args:
        idx: index of the next up block to run; past the end means finalize.
        sample_ref: one-element list holding the input tensor; cleared
            immediately so the caller's reference does not pin memory.
        ended: True when this is the last chunk of the temporal stream.
        run_up_state: RunUpState with conditioning, chunk-budget and
            pending-output state for this decode pass.
        output_buffer: destination tensor, or None to stash decoded frames
            on run_up_state.output_frames instead.
        output_offset: one-element list used as a mutable cursor into
            output_buffer's temporal dimension.
    """
    sample = sample_ref[0]
    sample_ref[0] = None  # release the caller's handle so it can be freed
    if idx >= len(self.up_blocks):
        # Final output stage: norm, optional shift/scale, activation, conv out.
        sample = self.conv_norm_out(sample)
        if run_up_state.timestep_shift_scale is not None:
            shift, scale = run_up_state.timestep_shift_scale
            sample = sample * (1 + scale) + shift
        sample = self.conv_act(sample)
        if ended:
            mark_conv3d_ended(self.conv_out)
        sample = self.conv_out(sample, causal=self.causal)
        if sample is not None and sample.shape[2] > 0:
            sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
            if output_buffer is None:
                # No buffer to fill: park the frames for the next resume.
                run_up_state.output_frames = sample
                return
            output_slice = output_buffer[:, :, output_offset[0]:output_offset[0] + sample.shape[2]]
            t = output_slice.shape[2]
            output_slice.copy_(sample[:, :, :t])
            output_offset[0] += t
            if t < sample.shape[2]:
                # Buffer filled mid-chunk; keep the overflow for later.
                run_up_state.output_frames = sample[:, :, t:]
        return

    up_block = self.up_blocks[idx]
    if ended:
        mark_conv3d_ended(up_block)
    if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
        sample = run_up_state.checkpoint_fn(up_block)(
            sample, causal=self.causal, timestep=run_up_state.scaled_timestep
        )
    else:
        sample = run_up_state.checkpoint_fn(up_block)(sample, causal=self.causal)

    if sample is None or sample.shape[2] == 0:
        return

    # Split along time so each recursive step stays under the byte budget.
    total_bytes = sample.numel() * sample.element_size()
    num_chunks = (total_bytes + run_up_state.max_chunk_size - 1) // run_up_state.max_chunk_size

    if num_chunks == 1:
        # Not chunking: detach our reference so the callee can free the
        # tensor as soon as it is done with it.
        next_sample_ref = [sample]
        del sample
        # Recurse unconditionally: a lower-level chunker or the
        # output-frame stash parks any overflow, so no fullness check here.
        self.run_up(idx + 1, next_sample_ref, ended, run_up_state, output_buffer, output_offset)
        return

    samples = list(torch.chunk(sample, chunks=num_chunks, dim=2))
    while samples:
        if output_buffer is None or output_offset[0] == output_buffer.shape[2]:
            # Output is full (or absent): defer the remaining chunks.
            run_up_state.pending_samples.append((idx + 1, samples, ended))
            return
        self.run_up(idx + 1, [samples.pop(0)], ended and len(samples) == 1, run_up_state, output_buffer, output_offset)

def forward_orig(
self,
Expand All @@ -591,6 +616,7 @@ def forward_orig(
) -> torch.FloatTensor:
r"""The forward method of the `Decoder` class."""
batch_size = sample.shape[0]
output_shape = self.decode_output_shape(sample.shape)

mark_conv3d_ended(self.conv_in)
sample = self.conv_in(sample, causal=self.causal)
Expand Down Expand Up @@ -630,29 +656,89 @@ def forward_orig(
)
timestep_shift_scale = ada_values.unbind(dim=1)

if output_buffer is None:
output_buffer = torch.empty(
self.decode_output_shape(sample.shape),
dtype=sample.dtype, device=comfy.model_management.intermediate_device(),
)
output_offset = [0]

max_chunk_size = get_max_chunk_size(sample.device)
run_up_state = RunUpState(
timestep_shift_scale=timestep_shift_scale,
scaled_timestep=scaled_timestep,
checkpoint_fn=checkpoint_fn,
max_chunk_size=get_max_chunk_size(sample.device),
output_shape=output_shape,
output_dtype=sample.dtype,
)
self.temporal_cache_state[threading.get_ident()] = run_up_state

self.run_up(0, [sample], True, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size)
self.run_up(0, [sample], True, run_up_state, output_buffer, output_offset)

return output_buffer

def forward_start(
    self,
    sample: torch.FloatTensor,
    timestep: Optional[torch.Tensor] = None,
):
    """Begin a streamed decode: run forward_orig with no output buffer so
    decoded frames accumulate in the per-thread RunUpState, to be drained
    later via forward_resume().

    On failure the per-thread temporal cache is cleared so a subsequent
    decode cannot resume from a half-finished stream.
    (This span interleaved old/new diff lines; reconstructed post-change.)
    """
    try:
        return self.forward_orig(sample, timestep=timestep, output_buffer=None)
    except Exception:
        clear_temporal_cache_state(self)
        raise

def forward_resume(self, output_t: int):
    """Drain up to *output_t* decoded frames from the in-progress stream.

    Returns None when no stream is active for the calling thread. When the
    stream is exhausted, the per-thread cache is cleared and the buffer is
    trimmed to the number of frames actually produced; otherwise a full
    *output_t*-frame buffer is returned and remaining work stays parked.
    """
    tid = threading.get_ident()
    run_up_state = self.temporal_cache_state.get(tid, None)
    if run_up_state is None:
        # No decode in progress for this thread.
        return None

    # Buffer shaped like the full output, but only output_t frames long.
    output_shape = list(run_up_state.output_shape)
    output_shape[2] = output_t
    output_buffer = torch.empty(
        output_shape,
        dtype=run_up_state.output_dtype, device=comfy.model_management.intermediate_device(),
    )
    output_offset = [0]  # mutable cursor into output_buffer's time axis

    try:
        # First flush frames left over from the previous drain.
        if run_up_state.output_frames is not None:
            output_slice = output_buffer[:, :, :run_up_state.output_frames.shape[2]]
            t = output_slice.shape[2]
            output_slice.copy_(run_up_state.output_frames[:, :, :t])
            output_offset[0] += t
            run_up_state.output_frames = None if t == run_up_state.output_frames.shape[2] else run_up_state.output_frames[:, :, t:]

        # Then resume parked (idx, chunks, ended) work items in order.
        pending_samples = run_up_state.pending_samples
        run_up_state.pending_samples = []
        while len(pending_samples):
            idx, samples, ended = pending_samples.pop(0)
            while len(samples):
                if output_offset[0] == output_buffer.shape[2]:
                    # Buffer full: put the unfinished item back at the front
                    # and re-park everything for the next resume call.
                    pending_samples = [(idx, samples, ended)] + pending_samples
                    run_up_state.pending_samples.extend(pending_samples)
                    return output_buffer
                sample1 = samples.pop(0)
                self.run_up(idx, [sample1], ended and len(samples) == 0, run_up_state, output_buffer, output_offset)

        # Stream exhausted: tear down state and trim to frames produced.
        if run_up_state.output_frames is None and not run_up_state.pending_samples:
            clear_temporal_cache_state(self)
        return output_buffer[:, :, :output_offset[0]]
    except Exception:
        clear_temporal_cache_state(self)
        raise

def forward(
    self,
    sample: torch.FloatTensor,
    timestep: Optional[torch.Tensor] = None,
    output_buffer: Optional[torch.Tensor] = None,
):
    """Decode *sample* in a single pass, writing into *output_buffer*.

    Allocates the buffer on the intermediate device when the caller did not
    supply one. The per-thread temporal cache is always cleared afterwards
    so a one-shot decode never leaves resumable stream state behind.
    (This span interleaved old/new diff lines; reconstructed post-change.)
    """
    if output_buffer is None:
        output_buffer = torch.empty(
            self.decode_output_shape(sample.shape),
            dtype=sample.dtype, device=comfy.model_management.intermediate_device(),
        )
    try:
        return self.forward_orig(sample, timestep=timestep, output_buffer=output_buffer)
    finally:
        clear_temporal_cache_state(self)


class UNetMidBlock3D(nn.Module):
Expand Down Expand Up @@ -1302,6 +1388,15 @@ def encode(self, x, device=None):
def decode_output_shape(self, input_shape):
    """Delegate decoded-output-shape computation to the wrapped decoder."""
    decoder = self.decoder
    return decoder.decode_output_shape(input_shape)

def decode_start(self, x):
    """Reset and begin a streamed decode of latent *x*.

    Resolves the seed TODO flagged in review: the conditioning noise is
    generated once and cached per run, so resetting and re-draining the
    same stream (e.g. a preview pass then a save pass) decodes the same
    latent identically instead of injecting fresh noise on every reset.
    """
    clear_temporal_cache_state(self.decoder)
    if self.timestep_conditioning:
        # Reuse one noise tensor while the latent's shape/dtype/device are
        # unchanged; regenerate only when a genuinely new run begins.
        noise = getattr(self, "_decode_start_noise", None)
        if (noise is None or noise.shape != x.shape
                or noise.dtype != x.dtype or noise.device != x.device):
            noise = torch.randn_like(x)
            self._decode_start_noise = noise
        x = noise * self.decode_noise_scale + (1.0 - self.decode_noise_scale) * x
    return self.decoder.forward_start(self.per_channel_statistics.un_normalize(x), timestep=self.decode_timestep)
Comment on lines +1391 to +1395
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Seed streamed decode once per run.

The #TODO: seed is observable here: every decode_start() reset injects fresh random noise, so the same latent is no longer replayable across two drains of the same stream. That breaks the new resettable-stream behavior and can make preview/save passes disagree if the source is re-executed. The noise/seed needs to be captured once for the run and reused for all resumed chunks.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@comfy/ldm/lightricks/vae/causal_video_autoencoder.py` around lines 1391 -
1395, decode_start currently injects fresh random noise every call which breaks
replayability; change it to generate and cache the noise once per run and reuse
it for subsequent decode_start calls. Specifically, add an instance field (e.g.,
self._decode_noise or self._decode_generator + self._decode_noise) and, in
decode_start, if that field is unset create the noise deterministically (use
torch.randn_like with a torch.Generator seeded once at run start) scaled by
self.decode_noise_scale and reused on later calls; keep
clear_temporal_cache_state(self.decoder) and continue to pass
timestep=self.decode_timestep to self.decoder.forward_start. Ensure there is
also a clear/reset method that clears the cached noise when a new run is
intentionally started.


def decode_chunk(self, output_t: int):
    """Pull up to *output_t* decoded frames from the active decode stream."""
    decoder = self.decoder
    return decoder.forward_resume(output_t)

def decode(self, x, output_buffer=None):
if self.timestep_conditioning: #TODO: seed
x = torch.randn_like(x) * self.decode_noise_scale + (1.0 - self.decode_noise_scale) * x
Expand Down
11 changes: 11 additions & 0 deletions comfy/sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -995,6 +995,17 @@ def decode(self, samples_in, vae_options={}):
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
return pixel_samples

def decode_output_shape(self, samples_shape):
    """Return the pixel-space shape this VAE would decode *samples_shape* to.

    Raises RuntimeError when the underlying model cannot report it.
    """
    self.throw_exception_if_invalid()
    model = self.first_stage_model
    if not hasattr(model, "decode_output_shape"):
        raise RuntimeError("This VAE does not expose decode output shape information.")
    return model.decode_output_shape(samples_shape)

def decode_stream_start(self, samples_in):
    """Load the VAE onto its device and start a streamed decode of *samples_in*.

    Frames are retrieved afterwards through the streaming chunk API rather
    than being returned here.
    """
    # Consistency fix: sibling entry points (decode, decode_tiled,
    # decode_output_shape) validate the VAE before use; do the same here.
    self.throw_exception_if_invalid()
    memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
    model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
    self.first_stage_model.decode_start(samples_in.to(device=self.device, dtype=self.vae_dtype))

def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
self.throw_exception_if_invalid()
memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile
Expand Down
2 changes: 2 additions & 0 deletions comfy_api/input/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from comfy_api.latest._input import (
ImageInput,
AudioInput,
ImageStreamInput,
MaskInput,
LatentInput,
VideoInput,
Expand All @@ -14,6 +15,7 @@
__all__ = [
"ImageInput",
"AudioInput",
"ImageStreamInput",
"MaskInput",
"LatentInput",
"VideoInput",
Expand Down
6 changes: 6 additions & 0 deletions comfy_api/input/image_stream_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# This file only exists for backwards compatibility.
from comfy_api.latest._input.image_stream_types import ImageStreamInput

__all__ = [
"ImageStreamInput",
]
3 changes: 2 additions & 1 deletion comfy_api/latest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from comfy_api.internal import ComfyAPIBase
from comfy_api.internal.singleton import ProxiedSingleton
from comfy_api.internal.async_to_sync import create_sync_class
from ._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
from ._input import ImageInput, AudioInput, ImageStreamInput, MaskInput, LatentInput, VideoInput
from ._input_impl import VideoFromFile, VideoFromComponents
from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL, File3D
from . import _io_public as io
Expand Down Expand Up @@ -131,6 +131,7 @@ async def get_node_list(self) -> list[type[io.ComfyNode]]:
class Input:
Image = ImageInput
Audio = AudioInput
ImageStream = ImageStreamInput
Mask = MaskInput
Latent = LatentInput
Video = VideoInput
Expand Down
2 changes: 2 additions & 0 deletions comfy_api/latest/_input/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from .basic_types import ImageInput, AudioInput, MaskInput, LatentInput
from .curve_types import CurvePoint, CurveInput, MonotoneCubicCurve, LinearCurve
from .image_stream_types import ImageStreamInput
from .video_types import VideoInput

__all__ = [
"ImageInput",
"AudioInput",
"ImageStreamInput",
"VideoInput",
"MaskInput",
"LatentInput",
Expand Down
Loading
Loading