Skip to content

Commit 8f9d582

Browse files
authored
Merge pull request #403 from Kosinkadink/hook_update_fix
Fix Lora Hooks for latest ComfyUI versions
2 parents e2313c4 + 89e449d commit 8f9d582

File tree

1 file changed

+50
-22
lines changed

1 file changed

+50
-22
lines changed

animatediff/model_injection.py

Lines changed: 50 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -141,12 +141,19 @@ def add_hooked_patches(self, lora_hook: LoraHook, patches, strength_patch=1.0, s
141141
# TODO: make this work with timestep scheduling
142142
current_hooked_patches: dict[str,list] = self.hooked_patches.get(lora_hook.hook_ref, {})
143143
p = set()
144-
for key in patches:
145-
model_sd = self.model.state_dict()
144+
model_sd = self.model.state_dict()
145+
for k in patches:
146+
offset = None
147+
if isinstance(k, str):
148+
key = k
149+
else:
150+
offset = k[1]
151+
key = k[0]
152+
146153
if key in model_sd:
147-
p.add(key)
154+
p.add(k)
148155
current_patches: list[tuple] = current_hooked_patches.get(key, [])
149-
current_patches.append((strength_patch, patches[key], strength_model))
156+
current_patches.append((strength_patch, patches[k], strength_model, offset))
150157
current_hooked_patches[key] = current_patches
151158
self.hooked_patches[lora_hook.hook_ref] = current_hooked_patches
152159
# since should care about these patches too to determine if same model, reroll patches_uuid
@@ -160,13 +167,20 @@ def add_hooked_patches_as_diffs(self, lora_hook: LoraHook, patches: dict, streng
160167
# TODO: make this work with timestep scheduling
161168
current_hooked_patches: dict[str,list] = self.hooked_patches.get(lora_hook.hook_ref, {})
162169
p = set()
163-
for key in patches:
164-
model_sd = self.model.state_dict()
170+
model_sd = self.model.state_dict()
171+
for k in patches:
172+
offset = None
173+
if isinstance(k, str):
174+
key = k
175+
else:
176+
offset = k[1]
177+
key = k[0]
178+
165179
if key in model_sd:
166-
p.add(key)
180+
p.add(k)
167181
current_patches: list[tuple] = current_hooked_patches.get(key, [])
168182
# take difference between desired weight and existing weight to get diff
169-
current_patches.append((strength_patch, (patches[key]-comfy.utils.get_attr(self.model, key),), strength_model))
183+
current_patches.append((strength_patch, (patches[k]-comfy.utils.get_attr(self.model, key),), strength_model, offset))
170184
current_hooked_patches[key] = current_patches
171185
self.hooked_patches[lora_hook.hook_ref] = current_hooked_patches
172186
# since should care about these patches too to determine if same model, reroll patches_uuid
@@ -433,7 +447,7 @@ def __init__(self, m: ModelPatcher):
433447
if hasattr(m, "object_patches_backup"):
434448
self.object_patches_backup = m.object_patches_backup
435449
# lora hook stuff
436-
self.hooked_patches = {} # binds LoraHook to specific keys
450+
self.hooked_patches: dict[HookRef] = {} # binds LoraHook to specific keys
437451
self.patches_backup = {}
438452
self.hooked_backup: dict[str, tuple[Tensor, torch.device]] = {}
439453

@@ -485,16 +499,23 @@ def add_hooked_patches(self, lora_hook: LoraHook, patches, strength_patch=1.0, s
485499
'''
486500
Based on add_patches, but for hooked weights.
487501
'''
488-
current_hooked_patches: dict[str,list] = self.hooked_patches.get(lora_hook, {})
502+
current_hooked_patches: dict[str,list] = self.hooked_patches.get(lora_hook.hook_ref, {})
489503
p = set()
490-
for key in patches:
491-
model_sd = self.model.state_dict()
504+
model_sd = self.model.state_dict()
505+
for k in patches:
506+
offset = None
507+
if isinstance(k, str):
508+
key = k
509+
else:
510+
offset = k[1]
511+
key = k[0]
512+
492513
if key in model_sd:
493-
p.add(key)
514+
p.add(k)
494515
current_patches: list[tuple] = current_hooked_patches.get(key, [])
495-
current_patches.append((strength_patch, patches[key], strength_model))
516+
current_patches.append((strength_patch, patches[k], strength_model, offset))
496517
current_hooked_patches[key] = current_patches
497-
self.hooked_patches[lora_hook] = current_hooked_patches
518+
self.hooked_patches[lora_hook.hook_ref] = current_hooked_patches
498519
# since should care about these patches too to determine if same model, reroll patches_uuid
499520
self.patches_uuid = uuid.uuid4()
500521
return list(p)
@@ -503,17 +524,24 @@ def add_hooked_patches_as_diffs(self, lora_hook: LoraHook, patches, strength_pat
503524
'''
504525
Based on add_hooked_patches, but intended for using a model's weights as lora hook.
505526
'''
506-
current_hooked_patches: dict[str,list] = self.hooked_patches.get(lora_hook, {})
527+
current_hooked_patches: dict[str,list] = self.hooked_patches.get(lora_hook.hook_ref, {})
507528
p = set()
508-
for key in patches:
509-
model_sd = self.model.state_dict()
529+
model_sd = self.model.state_dict()
530+
for k in patches:
531+
offset = None
532+
if isinstance(k, str):
533+
key = k
534+
else:
535+
offset = k[1]
536+
key = k[0]
537+
510538
if key in model_sd:
511-
p.add(key)
539+
p.add(k)
512540
current_patches: list[tuple] = current_hooked_patches.get(key, [])
513541
# take difference between desired weight and existing weight to get diff
514-
current_patches.append((strength_patch, (patches[key]-comfy.utils.get_attr(self.model, key),), strength_model))
542+
current_patches.append((strength_patch, (patches[k]-comfy.utils.get_attr(self.model, key),), strength_model, offset))
515543
current_hooked_patches[key] = current_patches
516-
self.hooked_patches[lora_hook] = current_hooked_patches
544+
self.hooked_patches[lora_hook.hook_ref] = current_hooked_patches
517545
# since should care about these patches too to determine if same model, reroll patches_uuid
518546
self.patches_uuid = uuid.uuid4()
519547
return list(p)
@@ -526,7 +554,7 @@ def get_combined_hooked_patches(self, lora_hooks: LoraHookGroup):
526554
combined_patches = {}
527555
if lora_hooks is not None:
528556
for hook in lora_hooks.hooks:
529-
hook_patches: dict = self.hooked_patches.get(hook, {})
557+
hook_patches: dict = self.hooked_patches.get(hook.hook_ref, {})
530558
for key in hook_patches.keys():
531559
current_patches: list[tuple] = combined_patches.get(key, [])
532560
current_patches.extend(hook_patches[key])

0 commit comments

Comments
 (0)