Skip to content

Commit 1b660e5

Browse files
authored
Merge PR #425 from Kosinkadink/develop - Visualize Context Options nodes
Visualize Context Options nodes
2 parents f3b24c1 + c8480f9 commit 1b660e5

File tree

4 files changed

+261
-19
lines changed

4 files changed

+261
-19
lines changed

animatediff/context.py

Lines changed: 182 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
1-
from typing import Callable, Optional, Union
1+
from typing import Union
22

3+
import torch
34
import torchvision
4-
import PIL
5+
from PIL import Image, ImageFont, ImageDraw
56

67
import numpy as np
78
from torch import Tensor
89

10+
import comfy.samplers
911
from comfy.model_base import BaseModel
12+
from comfy.model_patcher import ModelPatcher
1013

1114
from .utils_motion import get_sorted_list_via_attr
1215

@@ -76,7 +79,7 @@ def __init__(self):
7679
self._current_context: ContextOptions = None
7780
self._current_used_steps: int = 0
7881
self._current_index: int = 0
79-
self.step = 0
82+
self._step = 0
8083

8184
def reset(self):
8285
self._current_context = None
@@ -85,6 +88,15 @@ def reset(self):
8588
self.step = 0
8689
self._set_first_as_current()
8790

91+
@property
92+
def step(self):
93+
return self._step
94+
@step.setter
95+
def step(self, value: int):
96+
self._step = value
97+
if self._current_context is not None:
98+
self._current_context.step = value
99+
88100
@classmethod
89101
def default(cls):
90102
def_context = ContextOptions()
@@ -492,15 +504,176 @@ class Colors:
492504
CYAN = (0, 255, 255)
493505

494506

507+
class BorderWidth:
508+
INDEXES = 2
509+
CONTEXT = 4
510+
511+
495512
class VisualizeSettings:
496-
def __init__(self, img_width, img_height, video_length):
497-
self.img_width = img_width
498-
self.img_height = img_height
513+
def __init__(self, img_width: int, video_length: int):
499514
self.video_length = video_length
515+
self.img_width = img_width
500516
self.grid = img_width // video_length
517+
self.img_height = self.grid * 5
501518
self.pil_to_tensor = torchvision.transforms.Compose([torchvision.transforms.PILToTensor()])
519+
self.font_size = int(self.grid * 0.5)
520+
self.font = ImageFont.load_default(size=self.font_size)
521+
#self.title_font = ImageFont.load_default(size=int(self.font_size * 1.5))
522+
self.title_font = ImageFont.load_default(size=int(self.font_size * 1.2))
502523

524+
self.background_color = Colors.BLACK
525+
self.grid_outline_color = Colors.WHITE
526+
self.start_idx_fill_color = Colors.MAGENTA
527+
self.subidx_end_color = Colors.YELLOW
503528

504-
def generate_context_visualization(context_opts: ContextOptionsGroup, model: BaseModel, width=1440, height=200, video_length=32, start_step=0, end_step=20):
505-
vs = VisualizeSettings(width, height, video_length)
506-
pass
529+
self.context_color = Colors.GREEN
530+
self.view_color = Colors.RED
531+
532+
533+
class GridDisplay:
534+
def __init__(self, draw: ImageDraw.ImageDraw, vs: VisualizeSettings, home_x: int=0, home_y: int=0):
535+
self.home_x = home_x
536+
self.home_y = home_y
537+
self.draw = draw
538+
self.vs = vs
539+
540+
541+
def get_text_xy(input: str, font: ImageFont, x: int, y: int, centered=True):
542+
return (x, y,)
543+
544+
545+
def draw_text(text: str, font: ImageFont, gd: GridDisplay, x: int, y: int, color=Colors.WHITE, centered=True):
546+
x, y = get_text_xy(text, font, x, y, centered=centered)
547+
gd.draw.text(xy=(gd.home_x+x, gd.home_y+y), text=text, fill=color, font=font)
548+
549+
550+
def draw_first_grid_row(total_length: int, gd: GridDisplay, start_idx=-1):
551+
vs = gd.vs
552+
# the first row is white squares, with the indexes drawed in
553+
for i in range(total_length):
554+
x1 = gd.home_x+(vs.grid*i)
555+
y1 = gd.home_y
556+
x2 = x1 + vs.grid
557+
y2 = y1 + vs.grid
558+
559+
fill = None
560+
if i==start_idx:
561+
fill=vs.start_idx_fill_color
562+
gd.draw.rectangle(xy=(x1, y1, x2, y2), fill=fill, outline=vs.grid_outline_color, width=BorderWidth.INDEXES)
563+
draw_text(text=str(i), font=vs.font, gd=gd, x=vs.grid*i, y=0)
564+
565+
566+
def draw_subidxs(window: list[int], gd: GridDisplay, y_grid_offset: int, color: tuple):
567+
vs = gd.vs
568+
# with no indexes drawed in- just solid squares, mostly
569+
y_offset = vs.grid * y_grid_offset
570+
for i, val in enumerate(window):
571+
x1 = gd.home_x+(vs.grid*val)
572+
y1 = gd.home_y+y_offset
573+
x2 = x1 + vs.grid
574+
y2 = y1 + vs.grid
575+
fill_color = color
576+
# if at an end of indexes, make inside be different color
577+
if i == 0 or i == len(window)-1:
578+
fill_color = vs.subidx_end_color
579+
gd.draw.rectangle(xy=(x1, y1, x2, y2), fill=fill_color, outline=color, width=BorderWidth.CONTEXT)
580+
581+
582+
def draw_context(window: list[int], gd: GridDisplay):
583+
draw_subidxs(window=window, gd=gd, y_grid_offset=1, color=gd.vs.context_color)
584+
585+
586+
def draw_view(window: list[int], gd: GridDisplay):
587+
draw_subidxs(window=window, gd=gd, y_grid_offset=2, color=gd.vs.view_color)
588+
589+
590+
def generate_context_visualization(context_opts: ContextOptionsGroup, model: ModelPatcher, sampler_name: str=None, scheduler: str=None,
591+
width=1440, height=200, video_length=32,
592+
steps=None, start_step=None, end_step=None, sigmas=None, force_full_denoise=False, denoise=None):
593+
context_opts = context_opts.clone()
594+
vs = VisualizeSettings(width, video_length)
595+
all_imgs = []
596+
597+
if sigmas is None:
598+
sampler = comfy.samplers.KSampler(
599+
model=model, steps=steps, device="cpu", sampler=sampler_name, scheduler=scheduler,
600+
denoise=denoise, model_options=model.model_options,
601+
)
602+
sigmas = sampler.sigmas
603+
if end_step is not None and end_step < (len(sigmas) - 1):
604+
sigmas = sigmas[:end_step + 1]
605+
if force_full_denoise:
606+
sigmas[-1] = 0
607+
if start_step is not None:
608+
if start_step < (len(sigmas) - 1):
609+
sigmas = sigmas[start_step:]
610+
# remove last sigma, as sampling uses pairs of sigmas at a time (fence post problem)
611+
sigmas = sigmas[:-1]
612+
613+
context_opts.reset()
614+
context_opts.initialize_timesteps(model.model)
615+
616+
if start_step is None:
617+
start_step = 0 # use this in case start_step is provided, to display accurate step
618+
if steps is None:
619+
steps = len(sigmas)
620+
621+
for i, t in enumerate(sigmas):
622+
# make context_opts reflect current step/sigma
623+
context_opts.prepare_current_context([t])
624+
context_opts.step = start_step+i
625+
626+
# check if context should even be active in this case
627+
context_active = True
628+
if video_length < context_opts.context_length:
629+
context_active = False
630+
elif video_length == context_opts.context_length and not context_opts.use_on_equal_length:
631+
context_active = False
632+
633+
if context_active:
634+
context_windows = get_context_windows(num_frames=video_length, opts=context_opts)
635+
else:
636+
context_windows = [list(range(video_length))]
637+
start_idx = -1
638+
for j,window in enumerate(context_windows):
639+
repeat_count = 0
640+
view_windows = []
641+
total_repeats = 1
642+
view_options = context_opts.view_options
643+
if view_options is not None:
644+
view_active = True
645+
if len(window) < view_options.context_length:
646+
view_active = False
647+
elif video_length == view_options.context_length and not view_options.use_on_equal_length:
648+
view_active = False
649+
if view_active:
650+
view_windows = get_context_windows(num_frames=len(window), opts=view_options)
651+
total_repeats = len(view_windows)
652+
while total_repeats > repeat_count:
653+
# create new frame
654+
frame: Image = Image.new(mode="RGB", size=(vs.img_width, vs.img_height), color=vs.background_color)
655+
draw = ImageDraw.Draw(frame)
656+
gd = GridDisplay(draw=draw, vs=vs, home_x=0, home_y=vs.grid)
657+
# if views present, do view stuff
658+
if len(view_windows) > 0:
659+
converted_view = [window[x] for x in view_windows[repeat_count]]
660+
draw_view(window=converted_view, gd=gd)
661+
# draw context_type + current step
662+
title_str = f"{context_opts.context_schedule} - Step {context_opts.step+1}/{steps} (Context {j+1}/{len(context_windows)})"
663+
if len(view_windows) > 0:
664+
title_str = f"{title_str} (View {repeat_count+1}/{len(view_windows)})"
665+
draw_text(text=title_str, font=vs.title_font, gd=gd, x=0-gd.home_x, y=0-gd.home_y, centered=False)
666+
# draw first row (total length, white)
667+
if j == 0:
668+
start_idx = window[0]
669+
draw_first_grid_row(total_length=video_length, gd=gd, start_idx=start_idx)
670+
# draw context row
671+
draw_context(window=window, gd=gd)
672+
# save image + iterate repeat_count
673+
img: Tensor = vs.pil_to_tensor(frame)
674+
all_imgs.append(img)
675+
repeat_count += 1
676+
677+
images = torch.stack(all_imgs)
678+
images = images.movedim(1, -1).to(torch.float32)
679+
return images

animatediff/nodes.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@
2525
NoisedImageInjectionNode, NoisedImageInjectOptionsNode)
2626
from .nodes_sigma_schedule import (SigmaScheduleNode, RawSigmaScheduleNode, WeightedAverageSigmaScheduleNode, InterpolatedWeightedAverageSigmaScheduleNode, SplitAndCombineSigmaScheduleNode)
2727
from .nodes_context import (LegacyLoopedUniformContextOptionsNode, LoopedUniformContextOptionsNode, LoopedUniformViewOptionsNode, StandardUniformContextOptionsNode, StandardStaticContextOptionsNode, BatchedContextOptionsNode,
28-
StandardStaticViewOptionsNode, StandardUniformViewOptionsNode, ViewAsContextOptionsNode, VisualizeContextOptionsInt)
28+
StandardStaticViewOptionsNode, StandardUniformViewOptionsNode, ViewAsContextOptionsNode,
29+
VisualizeContextOptionsK, VisualizeContextOptionsKAdv, VisualizeContextOptionsSCustom)
2930
from .nodes_ad_settings import (AnimateDiffSettingsNode, ManualAdjustPENode, SweetspotStretchPENode, FullStretchPENode,
3031
WeightAdjustAllAddNode, WeightAdjustAllMultNode, WeightAdjustIndivAddNode, WeightAdjustIndivMultNode,
3132
WeightAdjustIndivAttnAddNode, WeightAdjustIndivAttnMultNode)
@@ -58,7 +59,9 @@
5859
"ADE_ViewsOnlyContextOptions": ViewAsContextOptionsNode,
5960
"ADE_BatchedContextOptions": BatchedContextOptionsNode,
6061
"ADE_AnimateDiffUniformContextOptions": LegacyLoopedUniformContextOptionsNode, # Legacy
61-
#"ADE_VisualizeContextOptions": VisualizeContextOptionsInt,
62+
"ADE_VisualizeContextOptionsK": VisualizeContextOptionsK,
63+
"ADE_VisualizeContextOptionsKAdv": VisualizeContextOptionsKAdv,
64+
"ADE_VisualizeContextOptionsSCustom": VisualizeContextOptionsSCustom,
6265
# View Opts
6366
"ADE_StandardStaticViewOptions": StandardStaticViewOptionsNode,
6467
"ADE_StandardUniformViewOptions": StandardUniformViewOptionsNode,
@@ -180,7 +183,9 @@
180183
"ADE_ViewsOnlyContextOptions": "Context Options◆Views Only [VRAM⇈] 🎭🅐🅓",
181184
"ADE_BatchedContextOptions": "Context Options◆Batched [Non-AD] 🎭🅐🅓",
182185
"ADE_AnimateDiffUniformContextOptions": "Context Options◆Looped Uniform 🎭🅐🅓", # Legacy
183-
"ADE_VisualizeContextOptions": "Visualize Context Options 🎭🅐🅓",
186+
"ADE_VisualizeContextOptionsK": "Visualize Context Options (K.) 🎭🅐🅓",
187+
"ADE_VisualizeContextOptionsKAdv": "Visualize Context Options (K.Adv.) 🎭🅐🅓",
188+
"ADE_VisualizeContextOptionsSCustom": "Visualize Context Options (S.Cus.) 🎭🅐🅓",
184189
# View Opts
185190
"ADE_StandardStaticViewOptions": "View Options◆Standard Static 🎭🅐🅓",
186191
"ADE_StandardUniformViewOptions": "View Options◆Standard Uniform 🎭🅐🅓",

animatediff/nodes_context.py

Lines changed: 70 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
import torch
22
from torch import Tensor
33

4+
import comfy.samplers
45
from comfy.model_patcher import ModelPatcher
56

6-
from .context import ContextFuseMethod, ContextOptions, ContextOptionsGroup, ContextSchedules
7-
from .utils_model import BIGMAX
7+
from .context import (ContextFuseMethod, ContextOptions, ContextOptionsGroup, ContextSchedules,
8+
generate_context_visualization)
9+
from .utils_model import BIGMAX, MAX_RESOLUTION
810

911

1012
LENGTH_MAX = 128 # keep an eye on these max values;
@@ -353,16 +355,20 @@ def create_options(self, view_length: int, view_overlap: int, view_stride: int,
353355
return (view_options,)
354356

355357

356-
class VisualizeContextOptionsInt:
358+
class VisualizeContextOptionsKAdv:
357359
@classmethod
358360
def INPUT_TYPES(s):
359361
return {
360362
"required": {
361363
"model": ("MODEL",),
362364
"context_opts": ("CONTEXT_OPTIONS",),
365+
"sampler_name": (comfy.samplers.KSampler.SAMPLERS, ),
366+
"scheduler": (comfy.samplers.KSampler.SCHEDULERS, ),
363367
},
364368
"optional": {
369+
"visual_width": ("INT", {"min": 32, "max": MAX_RESOLUTION, "default": 1440}),
365370
"latents_length": ("INT", {"min": 1, "max": BIGMAX, "default": 32}),
371+
"steps": ("INT", {"min": 0, "max": BIGMAX, "default": 20}),
366372
"start_step": ("INT", {"min": 0, "max": BIGMAX, "default": 0}),
367373
"end_step": ("INT", {"min": 1, "max": BIGMAX, "default": 20}),
368374
}
@@ -372,7 +378,65 @@ def INPUT_TYPES(s):
372378
CATEGORY = "Animate Diff 🎭🅐🅓/context opts/visualize"
373379
FUNCTION = "visualize"
374380

375-
def visualize(self, model: ModelPatcher, context_opts: ContextOptionsGroup,
376-
latents_length=32, start_step=0, end_step=20):
377-
images = torch.zeros((latents_length, 256, 256, 3))
381+
def visualize(self, model: ModelPatcher, context_opts: ContextOptionsGroup, sampler_name: str, scheduler: str,
382+
visual_width: 1280, latents_length=32, steps=20, start_step=0, end_step=20):
383+
images = generate_context_visualization(context_opts=context_opts, model=model, width=visual_width, video_length=latents_length,
384+
sampler_name=sampler_name, scheduler=scheduler,
385+
steps=steps, start_step=start_step, end_step=end_step)
386+
return (images,)
387+
388+
389+
class VisualizeContextOptionsK:
390+
@classmethod
391+
def INPUT_TYPES(s):
392+
return {
393+
"required": {
394+
"model": ("MODEL",),
395+
"context_opts": ("CONTEXT_OPTIONS",),
396+
"sampler_name": (comfy.samplers.KSampler.SAMPLERS, ),
397+
"scheduler": (comfy.samplers.KSampler.SCHEDULERS, ),
398+
},
399+
"optional": {
400+
"visual_width": ("INT", {"min": 32, "max": MAX_RESOLUTION, "default": 1440}),
401+
"latents_length": ("INT", {"min": 1, "max": BIGMAX, "default": 32}),
402+
"steps": ("INT", {"min": 0, "max": BIGMAX, "default": 20}),
403+
"denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
404+
}
405+
}
406+
407+
RETURN_TYPES = ("IMAGE",)
408+
CATEGORY = "Animate Diff 🎭🅐🅓/context opts/visualize"
409+
FUNCTION = "visualize"
410+
411+
def visualize(self, model: ModelPatcher, context_opts: ContextOptionsGroup, sampler_name: str, scheduler: str,
412+
visual_width: 1280, latents_length=32, steps=20, denoise=1.0):
413+
images = generate_context_visualization(context_opts=context_opts, model=model, width=visual_width, video_length=latents_length,
414+
sampler_name=sampler_name, scheduler=scheduler,
415+
steps=steps, denoise=denoise)
416+
return (images,)
417+
418+
419+
class VisualizeContextOptionsSCustom:
420+
@classmethod
421+
def INPUT_TYPES(s):
422+
return {
423+
"required": {
424+
"model": ("MODEL",),
425+
"context_opts": ("CONTEXT_OPTIONS",),
426+
"sigmas": ("SIGMAS", ),
427+
},
428+
"optional": {
429+
"visual_width": ("INT", {"min": 32, "max": MAX_RESOLUTION, "default": 1440}),
430+
"latents_length": ("INT", {"min": 1, "max": BIGMAX, "default": 32}),
431+
}
432+
}
433+
434+
RETURN_TYPES = ("IMAGE",)
435+
CATEGORY = "Animate Diff 🎭🅐🅓/context opts/visualize"
436+
FUNCTION = "visualize"
437+
438+
def visualize(self, model: ModelPatcher, context_opts: ContextOptionsGroup, sigmas,
439+
visual_width: 1280, latents_length=32):
440+
images = generate_context_visualization(context_opts=context_opts, model=model, width=visual_width, video_length=latents_length,
441+
sigmas=sigmas)
378442
return (images,)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[project]
22
name = "comfyui-animatediff-evolved"
33
description = "Improved AnimateDiff integration for ComfyUI."
4-
version = "1.0.9"
4+
version = "1.0.10"
55
license = "LICENSE"
66
dependencies = []
77

0 commit comments

Comments
 (0)