
Commit c98e8e6

Merge PR #405 from Kosinkadink/develop - PIA, Image Injection, Multival Expansion

PIA, Image Injection, and Multival Expansion

2 parents 379044e + 528ae52 · commit c98e8e6

13 files changed: +1213 −90 lines

README.md

Lines changed: 9 additions & 4 deletions
@@ -65,21 +65,26 @@ NOTE: you can also use custom locations for models/motion loras by making use of
   - NOTE: Requires the same settings as described for AnimateLCM above. Requires ```Apply AnimateLCM-I2V Model``` Gen2 node usage so that ```ref_latent``` can be provided; use the ```Scale Ref Image and VAE Encode``` node to preprocess input images. While this was intended as an img2video model, I found it works best for vid2vid purposes with ```ref_drift=0.0```, and to use it for at least 1 step before switching over to other models via chaining with other Apply AnimateDiff Model (Adv.) nodes. The ```apply_ref_when_disabled``` option can be set to True to allow the img_encoder to do its thing even when ```end_percent``` is reached. AnimateLCM-I2V is also extremely useful for maintaining coherence at higher resolutions (with ControlNet and SD LoRAs active, I could easily upscale from a 512x512 source to 1024x1024 in a single pass). TODO: add examples
 - [CameraCtrl](https://github.com/hehao13/CameraCtrl) support, with the pruned model you must use here: [CameraCtrl_pruned.safetensors](https://huggingface.co/Kosinkadink/CameraCtrl/tree/main)
   - NOTE: Requires AnimateDiff SD1.5 models, and was specifically trained for the v3 model. Gen2 only, with helper nodes provided under the Gen2/CameraCtrl submenu.
+- [PIA](https://github.com/open-mmlab/PIA) support, with the model [pia.ckpt](https://huggingface.co/Leoxing/PIA/tree/main)
+  - NOTE: You will need to use the ```autoselect``` or ```sqrt_linear (AnimateDiff)``` beta_schedule. Requires ```Apply AnimateDiff-PIA Model``` Gen2 node usage if you want to actually provide input images. The ```pia_input``` can be provided via the paper's presets (```PIA Input [Paper Presets]```) or by manually entering values (```PIA Input [Multival]```).
 - AnimateDiff Keyframes to change Scale and Effect at different points in the sampling process.
 - fp8 support; requires newest ComfyUI and torch >= 2.1 (decreases VRAM usage, but changes outputs)
 - Mac M1/M2/M3 support
 - Usage of Context Options and Sample Settings outside of AnimateDiff via the Gen2 Use Evolved Sampling node
 - Maskable and Schedulable SD LoRA (and Models as LoRA) for both AnimateDiff and StableDiffusion usage via LoRA Hooks
 - Per-frame GLIGEN coordinates control
   - Currently requires GLIGENTextBoxApplyBatch from KJNodes to do so, but I will add native nodes to do this soon.
+- Image Injection mid-sampling
 
 ## Upcoming Features
-- Example workflows for **every feature** in AnimateDiff-Evolved repo, and hopefully a long Youtube video showing all features (Goal: mid-May)
-- Maskable Motion LoRA (Goal: end of May/beginning of June)
+- Example workflows for **every feature** in the AnimateDiff-Evolved repo, and hopefully a long YouTube video showing all features (Goal: before the Elden Ring DLC releases. Working on it right now.)
+- [UniCtrl](https://github.com/XuweiyiChen/UniCtrl) support
+- Unet-Ref support so that a bunch of papers can be ported over
+- [StoryDiffusion](https://github.com/HVision-NKU/StoryDiffusion) implementation
+- Merging motion model weights/components, including per-block customization
+- Maskable Motion LoRA
 - Timestep schedulable GLIGEN coordinates
 - Dynamic memory management for motion models that load/unload at different start/end_percents
-- [PIA](https://github.com/open-mmlab/PIA) support
-- [UniCtrl](https://github.com/XuweiyiChen/UniCtrl) support
 - Built-in prompt travel implementation
 - Anything else AnimateDiff-related that comes out
 
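The ```PIA Input [Multival]``` node mentioned in the PIA note above ties into the "Multival Expansion" part of this PR: a multival is a value that may be either a plain float or a per-frame mask tensor, and the two forms combine uniformly. Below is a minimal sketch of that combining rule; the name mirrors `get_combined_multival` from `animatediff/utils_motion.py`, but the body is illustrative only (the real helper also handles resizing mismatched mask tensors).

```python
# Illustrative sketch of the "multival" idea, not the repo's exact code:
# a multival is either a float or a mask Tensor, and combining two of them
# is multiplicative, with None meaning "no override".
from typing import Union
import torch
from torch import Tensor

def combined_multival_sketch(a: Union[float, Tensor, None],
                             b: Union[float, Tensor, None]) -> Union[float, Tensor]:
    if a is None and b is None:
        return 1.0          # neutral element for multiplication
    if a is None:
        return b
    if b is None:
        return a
    # float*float, float*Tensor, and broadcastable Tensor*Tensor all work;
    # the real helper additionally resizes masks whose shapes differ
    return a * b
```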

animatediff/model_injection.py

Lines changed: 119 additions & 10 deletions
@@ -8,23 +8,26 @@
 import uuid
 import math
 
+import comfy.conds
 import comfy.lora
 import comfy.model_management
 import comfy.utils
 from comfy.model_patcher import ModelPatcher
 from comfy.model_base import BaseModel
-from comfy.sd import CLIP
+from comfy.sd import CLIP, VAE
 
 from .ad_settings import AnimateDiffSettings, AdjustPE, AdjustWeight
 from .adapter_cameractrl import CameraPoseEncoder, CameraEntry, prepare_pose_embedding
 from .context import ContextOptions, ContextOptions, ContextOptionsGroup
 from .motion_module_ad import (AnimateDiffModel, AnimateDiffFormat, EncoderOnlyAnimateDiffModel, VersatileAttention,
                                has_mid_block, normalize_ad_state_dict, get_position_encoding_max_len)
 from .logger import logger
-from .utils_motion import ADKeyframe, ADKeyframeGroup, MotionCompatibilityError, get_combined_multival, ade_broadcast_image_to, normalize_min_max
+from .utils_motion import (ADKeyframe, ADKeyframeGroup, MotionCompatibilityError, InputPIA,
+                           get_combined_multival, get_combined_input, get_combined_input_effect_multival,
+                           ade_broadcast_image_to, extend_to_batch_size, prepare_mask_batch)
 from .conditioning import HookRef, LoraHook, LoraHookGroup, LoraHookMode
 from .motion_lora import MotionLoraInfo, MotionLoraList
-from .utils_model import get_motion_lora_path, get_motion_model_path, get_sd_model_type
+from .utils_model import get_motion_lora_path, get_motion_model_path, get_sd_model_type, vae_encode_raw_batched
 from .sample_settings import SampleSettings, SeedNoiseGeneration
 
 
@@ -138,7 +141,6 @@ def add_hooked_patches(self, lora_hook: LoraHook, patches, strength_patch=1.0, s
         '''
         Based on add_patches, but for hooked weights.
         '''
-        # TODO: make this work with timestep scheduling
         current_hooked_patches: dict[str,list] = self.hooked_patches.get(lora_hook.hook_ref, {})
         p = set()
         model_sd = self.model.state_dict()
@@ -164,7 +166,6 @@ def add_hooked_patches_as_diffs(self, lora_hook: LoraHook, patches: dict, streng
         '''
         Based on add_hooked_patches, but intended for using a model's weights as lora hook.
         '''
-        # TODO: make this work with timestep scheduling
         current_hooked_patches: dict[str,list] = self.hooked_patches.get(lora_hook.hook_ref, {})
         p = set()
         model_sd = self.model.state_dict()
@@ -180,6 +181,7 @@ def add_hooked_patches_as_diffs(self, lora_hook: LoraHook, patches: dict, streng
             p.add(k)
             current_patches: list[tuple] = current_hooked_patches.get(key, [])
             # take difference between desired weight and existing weight to get diff
+            # TODO: create fix for fp8
             current_patches.append((strength_patch, (patches[k]-comfy.utils.get_attr(self.model, key),), strength_model, offset))
             current_hooked_patches[key] = current_patches
         self.hooked_patches[lora_hook.hook_ref] = current_hooked_patches
@@ -238,7 +240,7 @@ def patch_model_lowvram(self, *args, **kwargs):
                 self.model_params_lowvram_keys[f"{n}.weight"] = n
             if getattr(m, "bias_function", None) is not None:
                 self.model_params_lowvram = True
-                self.model_params_lowvram_keys[f"{n}.weight"] = n
+                self.model_params_lowvram_keys[f"{n}.bias"] = n
 
     def unpatch_model(self, device_to=None, unpatch_weights=True):
         # first, eject motion model from unet
@@ -721,7 +723,16 @@ def __init__(self, *args, **kwargs):
         self.orig_camera_entries: list[CameraEntry] = None
         self.camera_features: list[Tensor] = None # temporary
         self.camera_features_shape: tuple = None
-        self.cameractrl_multival = None
+        self.cameractrl_multival: Union[float, Tensor] = None
+
+        # PIA
+        self.orig_pia_images: Tensor = None
+        self.pia_vae: VAE = None
+        self.pia_input: InputPIA = None
+        self.cached_pia_c_concat: comfy.conds.CONDNoiseShape = None # cached
+        self.prev_pia_latents_shape: tuple = None
+        self.prev_current_pia_input: InputPIA = None
+        self.pia_multival: Union[float, Tensor] = None
 
         # temporary variables
         self.current_used_steps = 0
@@ -730,9 +741,12 @@ def __init__(self, *args, **kwargs):
         self.current_scale: Union[float, Tensor] = None
         self.current_effect: Union[float, Tensor] = None
         self.current_cameractrl_effect: Union[float, Tensor] = None
+        self.current_pia_input: InputPIA = None
         self.combined_scale: Union[float, Tensor] = None
         self.combined_effect: Union[float, Tensor] = None
         self.combined_cameractrl_effect: Union[float, Tensor] = None
+        self.combined_pia_mask: Union[float, Tensor] = None
+        self.combined_pia_effect: Union[float, Tensor] = None
         self.was_within_range = False
         self.prev_sub_idxs = None
         self.prev_batched_number = None
@@ -774,7 +788,7 @@ def initialize_timesteps(self, model: BaseModel):
         for keyframe in self.keyframes.keyframes:
             keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent)
 
-    def prepare_current_keyframe(self, t: Tensor):
+    def prepare_current_keyframe(self, x: Tensor, t: Tensor):
         curr_t: float = t[0]
         prev_index = self.current_index
         # if met guaranteed steps, look for next keyframe in case need to switch
@@ -802,6 +816,10 @@ def prepare_current_keyframe(self, t: Tensor):
                     self.current_cameractrl_effect = self.current_keyframe.cameractrl_multival
                 elif not self.current_keyframe.inherit_missing:
                     self.current_cameractrl_effect = None
+                if self.current_keyframe.has_pia_input():
+                    self.current_pia_input = self.current_keyframe.pia_input
+                elif not self.current_keyframe.inherit_missing:
+                    self.current_pia_input = None
                 # if guarantee_steps greater than zero, stop searching for other keyframes
                 if self.current_keyframe.guarantee_steps > 0:
                     break
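The added branch follows the same `inherit_missing` pattern as the scale/effect/cameractrl values above it: a keyframe that carries no `pia_input` either keeps the previously active one (`inherit_missing=True`) or clears it back to None.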
@@ -814,6 +832,8 @@ def prepare_current_keyframe(self, t: Tensor):
         self.combined_scale = get_combined_multival(self.scale_multival, self.current_scale)
         self.combined_effect = get_combined_multival(self.effect_multival, self.current_effect)
         self.combined_cameractrl_effect = get_combined_multival(self.cameractrl_multival, self.current_cameractrl_effect)
+        self.combined_pia_mask = get_combined_input(self.pia_input, self.current_pia_input, x)
+        self.combined_pia_effect = get_combined_input_effect_multival(self.pia_input, self.current_pia_input)
         # apply scale and effect
         self.model.set_scale(self.combined_scale)
         self.model.set_effect(self.combined_effect)
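Note the signature change threaded through this hunk (and through the group class further down): `prepare_current_keyframe` now receives the latents `x` alongside `t`. The latent shape is needed because `get_combined_input` resolves the active `InputPIA` into `combined_pia_mask`, a mask sized to match `x`, whereas the plain multival combines above it need no shape information.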
@@ -889,6 +909,72 @@ def prepare_camera_features(self, x: Tensor, cond_or_uncond: list[int], ad_param
         self.prev_sub_idxs = sub_idxs
         self.prev_batched_number = batched_number
 
+    def get_pia_c_concat(self, model: BaseModel, x: Tensor) -> Tensor:
+        # if have cached shape, check if matches - if so, return cached pia_latents
+        if self.prev_pia_latents_shape is not None:
+            if self.prev_pia_latents_shape[0] == x.shape[0] and self.prev_pia_latents_shape[2] == x.shape[2] and self.prev_pia_latents_shape[3] == x.shape[3]:
+                # if mask is also the same for this timestep, then return cached
+                if self.prev_current_pia_input == self.current_pia_input:
+                    return self.cached_pia_c_concat
+                # otherwise, adjust new mask, and create new cached_pia_c_concat
+                b, c, h, w = x.shape
+                mask = prepare_mask_batch(self.combined_pia_mask, x.shape)
+                mask = extend_to_batch_size(mask, b)
+                # make sure to update prev_current_pia_input to know when it has changed
+                self.prev_current_pia_input = self.current_pia_input
+                # TODO: handle self.combined_pia_effect eventually (feature hidden for now)
+                # the first index in dim=1 is the mask that needs to be updated - update in place
+                self.cached_pia_c_concat.cond[:, :1, :, :] = mask
+                return self.cached_pia_c_concat
+            self.prev_pia_latents_shape = None
+        # otherwise, encode fresh latents; x.shape will become the new cached shape
+        # get currently used models so they can be properly reloaded after performing VAE Encoding
+        if hasattr(comfy.model_management, "loaded_models"):
+            cached_loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
+        else:
+            cached_loaded_models: list[ModelPatcherAndInjector] = [x.model for x in comfy.model_management.current_loaded_models]
+        try:
+            b, c, h, w = x.shape
+            usable_ref = self.orig_pia_images[:b]
+            # in diffusers, the image is scaled from [-1, 1] instead of default [0, 1],
+            # but from my testing, that blows out the images here, so I skip it
+            # usable_images = usable_images * 2 - 1
+            # resize images to latent's dims
+            usable_ref = usable_ref.movedim(-1,1)
+            usable_ref = comfy.utils.common_upscale(samples=usable_ref, width=w*self.pia_vae.downscale_ratio, height=h*self.pia_vae.downscale_ratio,
+                                                    upscale_method="bilinear", crop="center")
+            usable_ref = usable_ref.movedim(1,-1)
+            # VAE encode images
+            logger.info("VAE Encoding PIA input images...")
+            usable_ref = model.process_latent_in(vae_encode_raw_batched(vae=self.pia_vae, pixels=usable_ref, show_pbar=False))
+            logger.info("VAE Encoding PIA input images complete.")
+            # make pia_latents match expected length
+            usable_ref = extend_to_batch_size(usable_ref, b)
+            self.prev_pia_latents_shape = x.shape
+            # now, take care of the mask
+            mask = prepare_mask_batch(self.combined_pia_mask, x.shape)
+            mask = extend_to_batch_size(mask, b)
+            #mask = mask.unsqueeze(1)
+            self.prev_current_pia_input = self.current_pia_input
+            if type(self.combined_pia_effect) == Tensor or not math.isclose(self.combined_pia_effect, 1.0):
+                real_pia_effect = self.combined_pia_effect
+                if type(self.combined_pia_effect) == Tensor:
+                    real_pia_effect = extend_to_batch_size(prepare_mask_batch(self.combined_pia_effect, x.shape), b)
+                zero_mask = torch.zeros_like(mask)
+                mask = mask * real_pia_effect + zero_mask * (1.0 - real_pia_effect)
+                del zero_mask
+                zero_usable_ref = torch.zeros_like(usable_ref)
+                usable_ref = usable_ref * real_pia_effect + zero_usable_ref * (1.0 - real_pia_effect)
+                del zero_usable_ref
+            # cache pia c_concat
+            self.cached_pia_c_concat = comfy.conds.CONDNoiseShape(torch.cat([mask, usable_ref], dim=1))
+            return self.cached_pia_c_concat
+        finally:
+            comfy.model_management.load_models_gpu(cached_loaded_models)
+
+    def is_pia(self):
+        return self.model.mm_info.mm_format == AnimateDiffFormat.PIA and self.orig_pia_images is not None
+
     def cleanup(self):
         if self.model is not None:
             self.model.cleanup()
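To make the tensor layout concrete: `get_pia_c_concat` produces a 5-channel conditioning for 4-channel latents, with the PIA animation mask in channel 0 and the VAE-encoded reference latents in channels 1-4; the cache fast-path only rewrites the mask channel in place. A toy reconstruction of those shapes (illustrative values only):

```python
# Toy shape check for the PIA c_concat built above: the mask (1 channel) is
# concatenated ahead of the reference latents (4 channels) along dim=1.
import torch

b, h, w = 16, 64, 64
mask = torch.ones(b, 1, h, w)          # PIA animation mask -> channel 0
ref_latents = torch.randn(b, 4, h, w)  # stand-in for VAE-encoded reference images
c_concat = torch.cat([mask, ref_latents], dim=1)
assert c_concat.shape == (b, 5, h, w)

# cache fast-path: same latent shape but a changed PIA input only needs the
# mask channel refreshed in place, matching cached_pia_c_concat.cond[:, :1]
c_concat[:, :1, :, :] = torch.zeros(b, 1, h, w)
```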
@@ -900,6 +986,9 @@ def cleanup(self):
             del self.camera_features
             self.camera_features = None
         self.camera_features_shape = None
+        # PIA
+        self.combined_pia_mask = None
+        self.combined_pia_effect = None
         # Default
         self.current_used_steps = 0
         self.current_keyframe = None
@@ -943,6 +1032,11 @@ def clone(self):
         # CameraCtrl
         n.orig_camera_entries = self.orig_camera_entries
         n.cameractrl_multival = self.cameractrl_multival
+        # PIA
+        n.orig_pia_images = self.orig_pia_images
+        n.pia_vae = self.pia_vae
+        n.pia_input = self.pia_input
+        n.pia_multival = self.pia_multival
         return n
 
 
@@ -995,9 +1089,16 @@ def cleanup(self):
         for motion_model in self.models:
             motion_model.cleanup()
 
-    def prepare_current_keyframe(self, t: Tensor):
+    def prepare_current_keyframe(self, x: Tensor, t: Tensor):
+        for motion_model in self.models:
+            motion_model.prepare_current_keyframe(x=x, t=t)
+
+    def get_pia_models(self):
+        pia_motion_models: list[MotionModelPatcher] = []
         for motion_model in self.models:
-            motion_model.prepare_current_keyframe(t=t)
+            if motion_model.is_pia():
+                pia_motion_models.append(motion_model)
+        return pia_motion_models
 
     def get_name_string(self, show_version=False):
         identifiers = []
@@ -1161,6 +1262,14 @@ def inject_img_encoder_into_model(motion_model: MotionModelPatcher, w_encoder: M
     motion_model.model.img_encoder.load_state_dict(w_encoder.model.img_encoder.state_dict())
 
 
+def inject_pia_conv_in_into_model(motion_model: MotionModelPatcher, w_pia: MotionModelPatcher):
+    motion_model.model.init_conv_in(w_pia.model.state_dict())
+    motion_model.model.conv_in.to(comfy.model_management.unet_dtype())
+    motion_model.model.conv_in.to(comfy.model_management.unet_offload_device())
+    motion_model.model.conv_in.load_state_dict(w_pia.model.conv_in.state_dict())
+    motion_model.model.mm_info.mm_format = AnimateDiffFormat.PIA
+
+
 def inject_camera_encoder_into_model(motion_model: MotionModelPatcher, camera_ctrl_name: str):
     camera_ctrl_path = get_motion_model_path(camera_ctrl_name)
     full_state_dict = comfy.utils.load_torch_file(camera_ctrl_path, safe_load=True)
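`inject_pia_conv_in_into_model` swaps a motion model's `conv_in` for PIA's version, which must accept the 4 noised latent channels plus the 5-channel c_concat above (9 input channels total). The repo's `init_conv_in` is not shown in this diff; as a hedged sketch, such an expansion typically builds a wider conv, copies the original input channels, and zero-initializes the new ones before the trained PIA weights are loaded over it:

```python
# Hypothetical sketch of a conv_in expansion (the actual init_conv_in may
# differ); zero-initialized extra channels leave behavior unchanged until
# real weights arrive, which the load_state_dict call above then provides.
import torch
import torch.nn as nn

def expand_conv_in_sketch(old_conv: nn.Conv2d, new_in_channels: int = 9) -> nn.Conv2d:
    new_conv = nn.Conv2d(new_in_channels, old_conv.out_channels,
                         kernel_size=old_conv.kernel_size, stride=old_conv.stride,
                         padding=old_conv.padding, bias=old_conv.bias is not None)
    with torch.no_grad():
        new_conv.weight.zero_()
        # copy weights for the original (e.g. 4) latent input channels
        new_conv.weight[:, :old_conv.in_channels] = old_conv.weight
        if old_conv.bias is not None:
            new_conv.bias.copy_(old_conv.bias)
    return new_conv
```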
