Skip to content

Commit f3b24c1

Browse files
authored
Merge PR #422 from Kosinkadink/develop - Custom CFG Improvements + GPU Noise
Custom CFG Improvements + GPU Noise
2 parents 7c75983 + b74d56c commit f3b24c1

11 files changed

+525
-56
lines changed

animatediff/cfg_extras.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
from typing import Union
2+
3+
import inspect
4+
import torch
5+
from torch import Tensor
6+
7+
import comfy.model_patcher
8+
import comfy.samplers
9+
10+
from .utils_motion import extend_to_batch_size, prepare_mask_batch
11+
12+
13+
################################################################################
14+
# helpers for modifying model_options to apply cfg function patches;
15+
# taken from comfy/model_patcher.py
16+
def set_model_options_sampler_cfg_function(model_options: dict[str], sampler_cfg_function, disable_cfg1_optimization=False):
17+
if len(inspect.signature(sampler_cfg_function).parameters) == 3:
18+
model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
19+
else:
20+
model_options["sampler_cfg_function"] = sampler_cfg_function
21+
if disable_cfg1_optimization:
22+
model_options["disable_cfg1_optimization"] = True
23+
return model_options
24+
#-------------------------------------------------------------------------------
25+
26+
27+
# this is a modified version of PerturbedAttentionGuidance from comfy_extras/nodes_pag.py
28+
def perturbed_attention_guidance_patch(scale_multival: Union[float, Tensor]):
29+
unet_block = "middle"
30+
unet_block_id = 0
31+
32+
def perturbed_attention(q, k, v, extra_options, mask=None):
33+
return v
34+
35+
def post_cfg_function(args):
36+
model = args["model"]
37+
cond_pred: Tensor = args["cond_denoised"]
38+
cond = args["cond"]
39+
cfg_result = args["denoised"]
40+
sigma = args["sigma"]
41+
model_options = args["model_options"].copy()
42+
x = args["input"]
43+
44+
if type(scale_multival) != Tensor and scale_multival == 0:
45+
return cfg_result
46+
47+
scale = scale_multival
48+
if isinstance(scale, Tensor):
49+
scale = prepare_mask_batch(scale.to(cond_pred.dtype).to(cond_pred.device), cond_pred.shape)
50+
scale = extend_to_batch_size(scale, cond_pred.shape[0])
51+
52+
# Replace Self-attention with PAG
53+
model_options = comfy.model_patcher.set_model_options_patch_replace(model_options, perturbed_attention, "attn1", unet_block, unet_block_id)
54+
(pag,) = comfy.samplers.calc_cond_batch(model, [cond], x, sigma, model_options)
55+
56+
return cfg_result + (cond_pred - pag) * scale
57+
58+
return post_cfg_function
59+
60+
61+
# this is a modified version of RescaleCFG from comfy_extras/nodes_model_advanced.py
62+
def rescale_cfg_patch(multiplier_multival: Union[float, Tensor]):
63+
def cfg_function(args):
64+
cond: Tensor = args["cond"]
65+
uncond = args["uncond"]
66+
cond_scale = args["cond_scale"]
67+
sigma = args["sigma"]
68+
sigma = sigma.view(sigma.shape[:1] + (1,) * (cond.ndim - 1))
69+
x_orig = args["input"]
70+
71+
#rescale cfg has to be done on v-pred model output
72+
x = x_orig / (sigma * sigma + 1.0)
73+
cond = ((x - (x_orig - cond)) * (sigma ** 2 + 1.0) ** 0.5) / (sigma)
74+
uncond = ((x - (x_orig - uncond)) * (sigma ** 2 + 1.0) ** 0.5) / (sigma)
75+
76+
#rescalecfg
77+
x_cfg = uncond + cond_scale * (cond - uncond)
78+
ro_pos = torch.std(cond, dim=(1,2,3), keepdim=True)
79+
ro_cfg = torch.std(x_cfg, dim=(1,2,3), keepdim=True)
80+
81+
multiplier = multiplier_multival
82+
if isinstance(multiplier, Tensor):
83+
multiplier = prepare_mask_batch(multiplier.to(cond.dtype).to(cond.device), cond.shape)
84+
multiplier = extend_to_batch_size(multiplier, cond.shape[0])
85+
86+
x_rescaled = x_cfg * (ro_pos / ro_cfg)
87+
x_final = multiplier * x_rescaled + (1.0 - multiplier) * x_cfg
88+
89+
return x_orig - (x - x_final * sigma / (sigma * sigma + 1.0) ** 0.5)
90+
91+
return cfg_function

animatediff/context.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
from typing import Callable, Optional, Union
22

3+
import torchvision
4+
import PIL
5+
36
import numpy as np
47
from torch import Tensor
58

@@ -473,3 +476,31 @@ def shift_window_to_end(window: list[int], num_frames: int):
473476
for i in range(len(window)):
474477
# 2) add end_delta to each val to slide windows to end
475478
window[i] = window[i] + end_delta
479+
480+
481+
##########################
482+
# Context Visualization
483+
##########################
484+
class Colors:
485+
BLACK = (0, 0, 0)
486+
WHITE = (255, 255, 255)
487+
RED = (255, 0, 0)
488+
GREEN = (0, 255, 0)
489+
BLUE = (0, 0, 255)
490+
YELLOW = (255, 255, 0)
491+
MAGENTA = (255, 0, 255)
492+
CYAN = (0, 255, 255)
493+
494+
495+
class VisualizeSettings:
496+
def __init__(self, img_width, img_height, video_length):
497+
self.img_width = img_width
498+
self.img_height = img_height
499+
self.video_length = video_length
500+
self.grid = img_width // video_length
501+
self.pil_to_tensor = torchvision.transforms.Compose([torchvision.transforms.PILToTensor()])
502+
503+
504+
def generate_context_visualization(context_opts: ContextOptionsGroup, model: BaseModel, width=1440, height=200, video_length=32, start_step=0, end_step=20):
505+
vs = VisualizeSettings(width, height, video_length)
506+
pass

animatediff/motion_module_ad.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,10 +134,10 @@ def has_img_encoder(mm_state_dict: dict[str, Tensor]):
134134

135135
def normalize_ad_state_dict(mm_state_dict: dict[str, Tensor], mm_name: str) -> Tuple[dict[str, Tensor], AnimateDiffInfo]:
136136
# from pathlib import Path
137-
# with open(Path(__file__).parent.parent.parent / f"keys_{mm_name}.txt", "w") as afile:
137+
# log_name = mm_name.split('\\')[-1]
138+
# with open(Path(__file__).parent.parent.parent / rf"keys_{log_name}.txt", "w") as afile:
138139
# for key, value in mm_state_dict.items():
139140
# afile.write(f"{key}:\t{value.shape}\n")
140-
141141
# determine what SD model the motion module is intended for
142142
sd_type: str = None
143143
down_block_max = get_down_block_max(mm_state_dict)

animatediff/nodes.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,16 @@
2020
ConditioningTimestepsNode, SetLoraHookKeyframes,
2121
CreateLoraHookKeyframe, CreateLoraHookKeyframeInterpolation, CreateLoraHookKeyframeFromStrengthList)
2222
from .nodes_sample import (FreeInitOptionsNode, NoiseLayerAddWeightedNode, SampleSettingsNode, NoiseLayerAddNode, NoiseLayerReplaceNode, IterationOptionsNode,
23-
CustomCFGNode, CustomCFGKeyframeNode, NoisedImageInjectionNode, NoisedImageInjectOptionsNode)
23+
CustomCFGNode, CustomCFGSimpleNode, CustomCFGKeyframeNode, CustomCFGKeyframeSimpleNode,
24+
CFGExtrasPAGNode, CFGExtrasPAGSimpleNode, CFGExtrasRescaleCFGNode, CFGExtrasRescaleCFGSimpleNode,
25+
NoisedImageInjectionNode, NoisedImageInjectOptionsNode)
2426
from .nodes_sigma_schedule import (SigmaScheduleNode, RawSigmaScheduleNode, WeightedAverageSigmaScheduleNode, InterpolatedWeightedAverageSigmaScheduleNode, SplitAndCombineSigmaScheduleNode)
2527
from .nodes_context import (LegacyLoopedUniformContextOptionsNode, LoopedUniformContextOptionsNode, LoopedUniformViewOptionsNode, StandardUniformContextOptionsNode, StandardStaticContextOptionsNode, BatchedContextOptionsNode,
26-
StandardStaticViewOptionsNode, StandardUniformViewOptionsNode, ViewAsContextOptionsNode)
28+
StandardStaticViewOptionsNode, StandardUniformViewOptionsNode, ViewAsContextOptionsNode, VisualizeContextOptionsInt)
2729
from .nodes_ad_settings import (AnimateDiffSettingsNode, ManualAdjustPENode, SweetspotStretchPENode, FullStretchPENode,
2830
WeightAdjustAllAddNode, WeightAdjustAllMultNode, WeightAdjustIndivAddNode, WeightAdjustIndivMultNode,
2931
WeightAdjustIndivAttnAddNode, WeightAdjustIndivAttnMultNode)
30-
from .nodes_extras import AnimateDiffUnload, EmptyLatentImageLarge, CheckpointLoaderSimpleWithNoiseSelect
32+
from .nodes_extras import AnimateDiffUnload, EmptyLatentImageLarge, CheckpointLoaderSimpleWithNoiseSelect, PerturbedAttentionGuidanceMultival, RescaleCFGMultival
3133
from .nodes_deprecated import (AnimateDiffLoader_Deprecated, AnimateDiffLoaderAdvanced_Deprecated, AnimateDiffCombine_Deprecated,
3234
AnimateDiffModelSettings, AnimateDiffModelSettingsSimple, AnimateDiffModelSettingsAdvanced, AnimateDiffModelSettingsAdvancedAttnStrengths)
3335
from .nodes_lora import AnimateDiffLoraLoader
@@ -56,6 +58,7 @@
5658
"ADE_ViewsOnlyContextOptions": ViewAsContextOptionsNode,
5759
"ADE_BatchedContextOptions": BatchedContextOptionsNode,
5860
"ADE_AnimateDiffUniformContextOptions": LegacyLoopedUniformContextOptionsNode, # Legacy
61+
#"ADE_VisualizeContextOptions": VisualizeContextOptionsInt,
5962
# View Opts
6063
"ADE_StandardStaticViewOptions": StandardStaticViewOptionsNode,
6164
"ADE_StandardUniformViewOptions": StandardUniformViewOptionsNode,
@@ -100,7 +103,9 @@
100103
"ADE_AdjustWeightIndivAttnAdd": WeightAdjustIndivAttnAddNode,
101104
"ADE_AdjustWeightIndivAttnMult": WeightAdjustIndivAttnMultNode,
102105
# Sample Settings
106+
"ADE_CustomCFGSimple": CustomCFGSimpleNode,
103107
"ADE_CustomCFG": CustomCFGNode,
108+
"ADE_CustomCFGKeyframeSimple": CustomCFGKeyframeSimpleNode,
104109
"ADE_CustomCFGKeyframe": CustomCFGKeyframeNode,
105110
"ADE_SigmaSchedule": SigmaScheduleNode,
106111
"ADE_RawSigmaSchedule": RawSigmaScheduleNode,
@@ -109,10 +114,16 @@
109114
"ADE_SigmaScheduleSplitAndCombine": SplitAndCombineSigmaScheduleNode,
110115
"ADE_NoisedImageInjection": NoisedImageInjectionNode,
111116
"ADE_NoisedImageInjectOptions": NoisedImageInjectOptionsNode,
117+
"ADE_CFGExtrasPAGSimple": CFGExtrasPAGSimpleNode,
118+
"ADE_CFGExtrasPAG": CFGExtrasPAGNode,
119+
"ADE_CFGExtrasRescaleCFGSimple": CFGExtrasRescaleCFGSimpleNode,
120+
"ADE_CFGExtrasRescaleCFG": CFGExtrasRescaleCFGNode,
112121
# Extras Nodes
113122
"ADE_AnimateDiffUnload": AnimateDiffUnload,
114123
"ADE_EmptyLatentImageLarge": EmptyLatentImageLarge,
115124
"CheckpointLoaderSimpleWithNoiseSelect": CheckpointLoaderSimpleWithNoiseSelect,
125+
"ADE_PerturbedAttentionGuidanceMultival": PerturbedAttentionGuidanceMultival,
126+
"ADE_RescaleCFGMultival": RescaleCFGMultival,
116127
# Gen1 Nodes
117128
"ADE_AnimateDiffLoaderGen1": AnimateDiffLoaderGen1,
118129
"ADE_AnimateDiffLoaderWithContext": LegacyAnimateDiffLoaderWithContext,
@@ -158,8 +169,8 @@
158169
"ADE_AnimateDiffSamplingSettings": "Sample Settings 🎭🅐🅓",
159170
"ADE_AnimateDiffKeyframe": "AnimateDiff Keyframe 🎭🅐🅓",
160171
# Multival Nodes
161-
"ADE_MultivalDynamic": "Multival Dynamic 🎭🅐🅓",
162-
"ADE_MultivalDynamicFloatInput": "Multival Dynamic [Float List] 🎭🅐🅓",
172+
"ADE_MultivalDynamic": "Multival 🎭🅐🅓",
173+
"ADE_MultivalDynamicFloatInput": "Multival [Float List] 🎭🅐🅓",
163174
"ADE_MultivalScaledMask": "Multival Scaled Mask 🎭🅐🅓",
164175
"ADE_MultivalConvertToMask": "Multival to Mask 🎭🅐🅓",
165176
# Context Opts
@@ -169,6 +180,7 @@
169180
"ADE_ViewsOnlyContextOptions": "Context Options◆Views Only [VRAM⇈] 🎭🅐🅓",
170181
"ADE_BatchedContextOptions": "Context Options◆Batched [Non-AD] 🎭🅐🅓",
171182
"ADE_AnimateDiffUniformContextOptions": "Context Options◆Looped Uniform 🎭🅐🅓", # Legacy
183+
"ADE_VisualizeContextOptions": "Visualize Context Options 🎭🅐🅓",
172184
# View Opts
173185
"ADE_StandardStaticViewOptions": "View Options◆Standard Static 🎭🅐🅓",
174186
"ADE_StandardUniformViewOptions": "View Options◆Standard Uniform 🎭🅐🅓",
@@ -213,19 +225,27 @@
213225
"ADE_AdjustWeightIndivAttnAdd": "Adjust Weight [Indiv-Attn◆Add] 🎭🅐🅓",
214226
"ADE_AdjustWeightIndivAttnMult": "Adjust Weight [Indiv-Attn◆Mult] 🎭🅐🅓",
215227
# Sample Settings
216-
"ADE_CustomCFG": "Custom CFG 🎭🅐🅓",
217-
"ADE_CustomCFGKeyframe": "Custom CFG Keyframe 🎭🅐🅓",
228+
"ADE_CustomCFGSimple": "Custom CFG 🎭🅐🅓",
229+
"ADE_CustomCFG": "Custom CFG [Multival] 🎭🅐🅓",
230+
"ADE_CustomCFGKeyframeSimple": "Custom CFG Keyframe 🎭🅐🅓",
231+
"ADE_CustomCFGKeyframe": "Custom CFG Keyframe [Multival] 🎭🅐🅓",
218232
"ADE_SigmaSchedule": "Create Sigma Schedule 🎭🅐🅓",
219233
"ADE_RawSigmaSchedule": "Create Raw Sigma Schedule 🎭🅐🅓",
220234
"ADE_SigmaScheduleWeightedAverage": "Sigma Schedule Weighted Mean 🎭🅐🅓",
221235
"ADE_SigmaScheduleWeightedAverageInterp": "Sigma Schedule Interpolated Mean 🎭🅐🅓",
222236
"ADE_SigmaScheduleSplitAndCombine": "Sigma Schedule Split Combine 🎭🅐🅓",
223237
"ADE_NoisedImageInjection": "Image Injection 🎭🅐🅓",
224238
"ADE_NoisedImageInjectOptions": "Image Injection Options 🎭🅐🅓",
239+
"ADE_CFGExtrasPAGSimple": "CFG Extras◆PAG 🎭🅐🅓",
240+
"ADE_CFGExtrasPAG": "CFG Extras◆PAG [Multival] 🎭🅐🅓",
241+
"ADE_CFGExtrasRescaleCFGSimple": "CFG Extras◆RescaleCFG 🎭🅐🅓",
242+
"ADE_CFGExtrasRescaleCFG": "CFG Extras◆RescaleCFG [Multival] 🎭🅐🅓",
225243
# Extras Nodes
226244
"ADE_AnimateDiffUnload": "AnimateDiff Unload 🎭🅐🅓",
227245
"ADE_EmptyLatentImageLarge": "Empty Latent Image (Big Batch) 🎭🅐🅓",
228246
"CheckpointLoaderSimpleWithNoiseSelect": "Load Checkpoint w/ Noise Select 🎭🅐🅓",
247+
"ADE_PerturbedAttentionGuidanceMultival": "PerturbedAttnGuide [Multival] 🎭🅐🅓",
248+
"ADE_RescaleCFGMultival": "RescaleCFG [Multival] 🎭🅐🅓",
229249
# Gen1 Nodes
230250
"ADE_AnimateDiffLoaderGen1": "AnimateDiff Loader 🎭🅐🅓①",
231251
"ADE_AnimateDiffLoaderWithContext": "AnimateDiff Loader [Legacy] 🎭🅐🅓①",

animatediff/nodes_context.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
import torch
2+
from torch import Tensor
3+
4+
from comfy.model_patcher import ModelPatcher
5+
16
from .context import ContextFuseMethod, ContextOptions, ContextOptionsGroup, ContextSchedules
27
from .utils_model import BIGMAX
38

@@ -346,3 +351,28 @@ def create_options(self, view_length: int, view_overlap: int, view_stride: int,
346351
use_on_equal_length=use_on_equal_length,
347352
)
348353
return (view_options,)
354+
355+
356+
class VisualizeContextOptionsInt:
357+
@classmethod
358+
def INPUT_TYPES(s):
359+
return {
360+
"required": {
361+
"model": ("MODEL",),
362+
"context_opts": ("CONTEXT_OPTIONS",),
363+
},
364+
"optional": {
365+
"latents_length": ("INT", {"min": 1, "max": BIGMAX, "default": 32}),
366+
"start_step": ("INT", {"min": 0, "max": BIGMAX, "default": 0}),
367+
"end_step": ("INT", {"min": 1, "max": BIGMAX, "default": 20}),
368+
}
369+
}
370+
371+
RETURN_TYPES = ("IMAGE",)
372+
CATEGORY = "Animate Diff 🎭🅐🅓/context opts/visualize"
373+
FUNCTION = "visualize"
374+
375+
def visualize(self, model: ModelPatcher, context_opts: ContextOptionsGroup,
376+
latents_length=32, start_step=0, end_step=20):
377+
images = torch.zeros((latents_length, 256, 256, 3))
378+
return (images,)

animatediff/nodes_deprecated.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ def INPUT_TYPES(s):
292292
},
293293
"optional": {
294294
"mask_motion_scale": ("MASK",),
295-
"optional": {"deprecation_warning": ("ADEWARN", {"text": "Deprecated"})},
295+
"deprecation_warning": ("ADEWARN", {"text": "Deprecated"}),
296296
}
297297
}
298298

@@ -321,7 +321,7 @@ def INPUT_TYPES(s):
321321
"mask_motion_scale": ("MASK",),
322322
"min_motion_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "step": 0.001}),
323323
"max_motion_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "step": 0.001}),
324-
"optional": {"deprecation_warning": ("ADEWARN", {"text": "Deprecated"})},
324+
"deprecation_warning": ("ADEWARN", {"text": "Deprecated"}),
325325
}
326326
}
327327

@@ -360,7 +360,7 @@ def INPUT_TYPES(s):
360360
"mask_motion_scale": ("MASK",),
361361
"min_motion_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "step": 0.001}),
362362
"max_motion_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "step": 0.001}),
363-
"optional": {"deprecation_warning": ("ADEWARN", {"text": "Deprecated"})},
363+
"deprecation_warning": ("ADEWARN", {"text": "Deprecated"}),
364364
}
365365
}
366366

@@ -415,7 +415,7 @@ def INPUT_TYPES(s):
415415
"mask_motion_scale": ("MASK",),
416416
"min_motion_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "step": 0.001}),
417417
"max_motion_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "step": 0.001}),
418-
"optional": {"deprecation_warning": ("ADEWARN", {"text": "Deprecated"})},
418+
"deprecation_warning": ("ADEWARN", {"text": "Deprecated"}),
419419
}
420420
}
421421

animatediff/nodes_extras.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,20 @@
1+
from typing import Union
2+
13
import torch
4+
from torch import Tensor
25

36
import folder_paths
47
import nodes as comfy_nodes
58
from comfy.model_patcher import ModelPatcher
9+
import comfy.model_patcher
10+
import comfy.samplers
611
from comfy.sd import load_checkpoint_guess_config
712

813
from .logger import logger
914
from .utils_model import BetaSchedules
15+
from .utils_motion import extend_to_batch_size, prepare_mask_batch
1016
from .model_injection import get_vanilla_model_patcher
17+
from .cfg_extras import perturbed_attention_guidance_patch, rescale_cfg_patch
1118

1219

1320
class AnimateDiffUnload:
@@ -76,3 +83,46 @@ def INPUT_TYPES(s):
7683
def generate(self, width, height, batch_size=1):
7784
latent = torch.zeros([batch_size, 4, height // 8, width // 8])
7885
return ({"samples":latent}, )
86+
87+
88+
class PerturbedAttentionGuidanceMultival:
89+
@classmethod
90+
def INPUT_TYPES(s):
91+
return {
92+
"required": {
93+
"model": ("MODEL",),
94+
"scale_multival": ("MULTIVAL",),
95+
}
96+
}
97+
98+
RETURN_TYPES = ("MODEL",)
99+
FUNCTION = "patch"
100+
101+
CATEGORY = "Animate Diff 🎭🅐🅓/extras"
102+
103+
def patch(self, model: ModelPatcher, scale_multival: Union[float, Tensor]):
104+
m = model.clone()
105+
m.set_model_sampler_post_cfg_function(perturbed_attention_guidance_patch(scale_multival))
106+
107+
return (m,)
108+
109+
110+
class RescaleCFGMultival:
111+
@classmethod
112+
def INPUT_TYPES(s):
113+
return {
114+
"required": {
115+
"model": ("MODEL",),
116+
"mult_multival": ("MULTIVAL",),
117+
}
118+
}
119+
120+
RETURN_TYPES = ("MODEL",)
121+
FUNCTION = "patch"
122+
123+
CATEGORY = "Animate Diff 🎭🅐🅓/extras"
124+
125+
def patch(self, model: ModelPatcher, mult_multival: Union[float, Tensor]):
126+
m = model.clone()
127+
m.set_model_sampler_cfg_function(rescale_cfg_patch(mult_multival))
128+
return (m, )

0 commit comments

Comments
 (0)