1
- from typing import Callable , Optional , Union
1
+ from typing import Union
2
2
3
+ import torch
3
4
import torchvision
4
- import PIL
5
+ from PIL import Image , ImageFont , ImageDraw
5
6
6
7
import numpy as np
7
8
from torch import Tensor
8
9
10
+ import comfy .samplers
9
11
from comfy .model_base import BaseModel
12
+ from comfy .model_patcher import ModelPatcher
10
13
11
14
from .utils_motion import get_sorted_list_via_attr
12
15
@@ -76,7 +79,7 @@ def __init__(self):
76
79
self ._current_context : ContextOptions = None
77
80
self ._current_used_steps : int = 0
78
81
self ._current_index : int = 0
79
- self .step = 0
82
+ self ._step = 0
80
83
81
84
def reset (self ):
82
85
self ._current_context = None
@@ -85,6 +88,15 @@ def reset(self):
85
88
self .step = 0
86
89
self ._set_first_as_current ()
87
90
91
    @property
    def step(self):
        # Current sampling step for the group; kept in a private field so the
        # setter below can mirror it onto the active context.
        return self._step
    @step.setter
    def step(self, value: int):
        # Propagate the step to the currently-selected ContextOptions (if any)
        # so per-context state sees the same step counter as the group.
        self._step = value
        if self._current_context is not None:
            self._current_context.step = value
99
+
88
100
@classmethod
89
101
def default (cls ):
90
102
def_context = ContextOptions ()
@@ -492,15 +504,176 @@ class Colors:
492
504
CYAN = (0 , 255 , 255 )
493
505
494
506
507
class BorderWidth:
    # Outline widths (pixels) for rectangles drawn in the visualization grid.
    INDEXES = 2  # thin outline for the numbered index row
    CONTEXT = 4  # thicker outline for context/view window rows
510
+
511
+
495
512
class VisualizeSettings:
    """Precomputed sizes, fonts, and colors shared by the context-visualization drawing helpers.

    The frame is a horizontal strip of square cells, one cell per video frame;
    all other dimensions are derived from that cell ("grid") size.
    Fix vs. previous revision: removed the commented-out dead title-font line.
    """
    def __init__(self, img_width: int, video_length: int):
        self.video_length = video_length
        self.img_width = img_width
        # side length of one square cell: the image width split across all frames
        self.grid = img_width // video_length
        # image is 5 cells tall — presumably title + index/context/view rows
        # plus padding; TODO confirm against the drawing code's row offsets
        self.img_height = self.grid * 5
        self.pil_to_tensor = torchvision.transforms.Compose([torchvision.transforms.PILToTensor()])
        self.font_size = int(self.grid * 0.5)
        self.font = ImageFont.load_default(size=self.font_size)
        self.title_font = ImageFont.load_default(size=int(self.font_size * 1.2))

        self.background_color = Colors.BLACK
        self.grid_outline_color = Colors.WHITE
        self.start_idx_fill_color = Colors.MAGENTA
        self.subidx_end_color = Colors.YELLOW

        self.context_color = Colors.GREEN
        self.view_color = Colors.RED
531
+
532
+
533
class GridDisplay:
    """Bundle of everything a drawing helper needs: the PIL draw handle,
    the shared VisualizeSettings, and the (home_x, home_y) grid origin."""
    def __init__(self, draw: "ImageDraw.ImageDraw", vs: "VisualizeSettings", home_x: int = 0, home_y: int = 0):
        self.draw = draw
        self.vs = vs
        self.home_x = home_x
        self.home_y = home_y
539
+
540
+
541
def get_text_xy(input: str, font: "ImageFont", x: int, y: int, centered=True):
    """Return the anchor point for drawing *input* at (x, y).

    NOTE(review): placeholder — centering is not implemented, so the text,
    font, and ``centered`` flag are ignored and (x, y) is returned unchanged.
    Also, ``input`` shadows the builtin of the same name.
    """
    return x, y
543
+
544
+
545
def draw_text(text: str, font: "ImageFont", gd: "GridDisplay", x: int, y: int, color=Colors.WHITE, centered=True):
    """Draw *text* at (x, y) measured from the grid's home origin."""
    rel_x, rel_y = get_text_xy(text, font, x, y, centered=centered)
    gd.draw.text(xy=(gd.home_x + rel_x, gd.home_y + rel_y), text=text, fill=color, font=font)
548
+
549
+
550
def draw_first_grid_row(total_length: int, gd: "GridDisplay", start_idx=-1):
    """Draw the top row of the grid: one outlined cell per frame, numbered
    0..total_length-1, with the cell at *start_idx* (if any) filled with the
    start-index color."""
    vs = gd.vs
    for idx in range(total_length):
        left = gd.home_x + (vs.grid * idx)
        top = gd.home_y
        # only the starting index gets an interior fill
        fill = vs.start_idx_fill_color if idx == start_idx else None
        gd.draw.rectangle(xy=(left, top, left + vs.grid, top + vs.grid),
                          fill=fill, outline=vs.grid_outline_color, width=BorderWidth.INDEXES)
        draw_text(text=str(idx), font=vs.font, gd=gd, x=vs.grid * idx, y=0)
564
+
565
+
566
def draw_subidxs(window: list[int], gd: "GridDisplay", y_grid_offset: int, color: tuple):
    """Draw one row of solid cells for the frame indexes in *window*.

    *y_grid_offset* selects which row (in cells) below home the boxes land on;
    the first and last entries of the window are filled with the end color so
    window boundaries stand out.
    """
    vs = gd.vs
    row_top = gd.home_y + vs.grid * y_grid_offset
    last = len(window) - 1
    for pos, frame_idx in enumerate(window):
        left = gd.home_x + (vs.grid * frame_idx)
        # endpoints of the window get a distinct interior color
        fill_color = vs.subidx_end_color if pos in (0, last) else color
        gd.draw.rectangle(xy=(left, row_top, left + vs.grid, row_top + vs.grid),
                          fill=fill_color, outline=color, width=BorderWidth.CONTEXT)
580
+
581
+
582
def draw_context(window: list[int], gd: "GridDisplay"):
    """Draw a context window on row 1 (directly below the index row)."""
    draw_subidxs(window, gd, 1, gd.vs.context_color)
584
+
585
+
586
def draw_view(window: list[int], gd: "GridDisplay"):
    """Draw a view window on row 2 (below the context row)."""
    draw_subidxs(window, gd, 2, gd.vs.view_color)
588
+
589
+
590
def generate_context_visualization(context_opts: ContextOptionsGroup, model: ModelPatcher, sampler_name: str=None, scheduler: str=None,
                                   width=1440, height=200, video_length=32,
                                   steps=None, start_step=None, end_step=None, sigmas=None, force_full_denoise=False, denoise=None):
    """Render one RGB frame per (step, context window[, view window]) showing
    which frame indexes each window covers, and return them stacked as a
    float32 tensor in channels-last layout.

    NOTE(review): ``height`` appears unused — frame height comes from
    VisualizeSettings' grid size; presumably kept for caller compatibility.
    """
    context_opts = context_opts.clone()  # avoid mutating the caller's options
    vs = VisualizeSettings(width, video_length)
    all_imgs = []

    if sigmas is None:
        # Recreate the sigma schedule the real sampler would use (on CPU).
        sampler = comfy.samplers.KSampler(
            model=model, steps=steps, device="cpu", sampler=sampler_name, scheduler=scheduler,
            denoise=denoise, model_options=model.model_options,
        )
        sigmas = sampler.sigmas
        # Trim to the requested step range, mirroring sampler behavior:
        # end trim (with optional final full denoise) before start trim.
        if end_step is not None and end_step < (len(sigmas) - 1):
            sigmas = sigmas[:end_step + 1]
            if force_full_denoise:
                sigmas[-1] = 0
        if start_step is not None:
            if start_step < (len(sigmas) - 1):
                sigmas = sigmas[start_step:]
    # remove last sigma, as sampling uses pairs of sigmas at a time (fence post problem)
    sigmas = sigmas[:-1]

    context_opts.reset()
    context_opts.initialize_timesteps(model.model)

    if start_step is None:
        start_step = 0  # use this in case start_step is provided, to display accurate step
    if steps is None:
        steps = len(sigmas)

    for i, t in enumerate(sigmas):
        # make context_opts reflect current step/sigma
        context_opts.prepare_current_context([t])
        context_opts.step = start_step + i

        # check if context should even be active in this case
        context_active = True
        if video_length < context_opts.context_length:
            context_active = False
        elif video_length == context_opts.context_length and not context_opts.use_on_equal_length:
            context_active = False

        if context_active:
            context_windows = get_context_windows(num_frames=video_length, opts=context_opts)
        else:
            # no windowing: a single "window" covering every frame
            context_windows = [list(range(video_length))]
        start_idx = -1
        for j, window in enumerate(context_windows):
            repeat_count = 0
            view_windows = []
            total_repeats = 1  # one image per context window unless views apply
            view_options = context_opts.view_options
            if view_options is not None:
                # views apply the same activation rules, but within one window
                view_active = True
                if len(window) < view_options.context_length:
                    view_active = False
                elif video_length == view_options.context_length and not view_options.use_on_equal_length:
                    view_active = False
                if view_active:
                    view_windows = get_context_windows(num_frames=len(window), opts=view_options)
                    total_repeats = len(view_windows)
            while total_repeats > repeat_count:
                # create new frame
                frame: Image = Image.new(mode="RGB", size=(vs.img_width, vs.img_height), color=vs.background_color)
                draw = ImageDraw.Draw(frame)
                # home_y = one grid cell down, leaving the top row for the title
                gd = GridDisplay(draw=draw, vs=vs, home_x=0, home_y=vs.grid)
                # if views present, do view stuff
                if len(view_windows) > 0:
                    # view indexes are relative to the window; map to absolute frames
                    converted_view = [window[x] for x in view_windows[repeat_count]]
                    draw_view(window=converted_view, gd=gd)
                # draw context_type + current step
                title_str = f"{context_opts.context_schedule} - Step {context_opts.step+1}/{steps} (Context {j+1}/{len(context_windows)})"
                if len(view_windows) > 0:
                    title_str = f"{title_str} (View {repeat_count+1}/{len(view_windows)})"
                # title is drawn at the absolute top-left, hence the home offsets are undone
                draw_text(text=title_str, font=vs.title_font, gd=gd, x=0-gd.home_x, y=0-gd.home_y, centered=False)
                # draw first row (total length, white)
                if j == 0:
                    start_idx = window[0]
                draw_first_grid_row(total_length=video_length, gd=gd, start_idx=start_idx)
                # draw context row
                draw_context(window=window, gd=gd)
                # save image + iterate repeat_count
                img: Tensor = vs.pil_to_tensor(frame)
                all_imgs.append(img)
                repeat_count += 1

    # (N, C, H, W) uint8 -> (N, H, W, C) float32, the layout ComfyUI images use
    images = torch.stack(all_imgs)
    images = images.movedim(1, -1).to(torch.float32)
    return images
0 commit comments