@@ -141,12 +141,19 @@ def add_hooked_patches(self, lora_hook: LoraHook, patches, strength_patch=1.0, s
141
141
# TODO: make this work with timestep scheduling
142
142
current_hooked_patches : dict [str ,list ] = self .hooked_patches .get (lora_hook .hook_ref , {})
143
143
p = set ()
144
- for key in patches :
145
- model_sd = self .model .state_dict ()
144
+ model_sd = self .model .state_dict ()
145
+ for k in patches :
146
+ offset = None
147
+ if isinstance (k , str ):
148
+ key = k
149
+ else :
150
+ offset = k [1 ]
151
+ key = k [0 ]
152
+
146
153
if key in model_sd :
147
- p .add (key )
154
+ p .add (k )
148
155
current_patches : list [tuple ] = current_hooked_patches .get (key , [])
149
- current_patches .append ((strength_patch , patches [key ], strength_model ))
156
+ current_patches .append ((strength_patch , patches [k ], strength_model , offset ))
150
157
current_hooked_patches [key ] = current_patches
151
158
self .hooked_patches [lora_hook .hook_ref ] = current_hooked_patches
152
159
# since should care about these patches too to determine if same model, reroll patches_uuid
@@ -160,13 +167,20 @@ def add_hooked_patches_as_diffs(self, lora_hook: LoraHook, patches: dict, streng
160
167
# TODO: make this work with timestep scheduling
161
168
current_hooked_patches : dict [str ,list ] = self .hooked_patches .get (lora_hook .hook_ref , {})
162
169
p = set ()
163
- for key in patches :
164
- model_sd = self .model .state_dict ()
170
+ model_sd = self .model .state_dict ()
171
+ for k in patches :
172
+ offset = None
173
+ if isinstance (k , str ):
174
+ key = k
175
+ else :
176
+ offset = k [1 ]
177
+ key = k [0 ]
178
+
165
179
if key in model_sd :
166
- p .add (key )
180
+ p .add (k )
167
181
current_patches : list [tuple ] = current_hooked_patches .get (key , [])
168
182
# take difference between desired weight and existing weight to get diff
169
- current_patches .append ((strength_patch , (patches [key ]- comfy .utils .get_attr (self .model , key ),), strength_model ))
183
+ current_patches .append ((strength_patch , (patches [k ]- comfy .utils .get_attr (self .model , key ),), strength_model , offset ))
170
184
current_hooked_patches [key ] = current_patches
171
185
self .hooked_patches [lora_hook .hook_ref ] = current_hooked_patches
172
186
# since should care about these patches too to determine if same model, reroll patches_uuid
@@ -433,7 +447,7 @@ def __init__(self, m: ModelPatcher):
433
447
if hasattr (m , "object_patches_backup" ):
434
448
self .object_patches_backup = m .object_patches_backup
435
449
# lora hook stuff
436
- self .hooked_patches = {} # binds LoraHook to specific keys
450
+ self .hooked_patches : dict [ HookRef ] = {} # binds LoraHook to specific keys
437
451
self .patches_backup = {}
438
452
self .hooked_backup : dict [str , tuple [Tensor , torch .device ]] = {}
439
453
@@ -485,16 +499,23 @@ def add_hooked_patches(self, lora_hook: LoraHook, patches, strength_patch=1.0, s
485
499
'''
486
500
Based on add_patches, but for hooked weights.
487
501
'''
488
- current_hooked_patches : dict [str ,list ] = self .hooked_patches .get (lora_hook , {})
502
+ current_hooked_patches : dict [str ,list ] = self .hooked_patches .get (lora_hook . hook_ref , {})
489
503
p = set ()
490
- for key in patches :
491
- model_sd = self .model .state_dict ()
504
+ model_sd = self .model .state_dict ()
505
+ for k in patches :
506
+ offset = None
507
+ if isinstance (k , str ):
508
+ key = k
509
+ else :
510
+ offset = k [1 ]
511
+ key = k [0 ]
512
+
492
513
if key in model_sd :
493
- p .add (key )
514
+ p .add (k )
494
515
current_patches : list [tuple ] = current_hooked_patches .get (key , [])
495
- current_patches .append ((strength_patch , patches [key ], strength_model ))
516
+ current_patches .append ((strength_patch , patches [k ], strength_model , offset ))
496
517
current_hooked_patches [key ] = current_patches
497
- self .hooked_patches [lora_hook ] = current_hooked_patches
518
+ self .hooked_patches [lora_hook . hook_ref ] = current_hooked_patches
498
519
# since should care about these patches too to determine if same model, reroll patches_uuid
499
520
self .patches_uuid = uuid .uuid4 ()
500
521
return list (p )
@@ -503,17 +524,24 @@ def add_hooked_patches_as_diffs(self, lora_hook: LoraHook, patches, strength_pat
503
524
'''
504
525
Based on add_hooked_patches, but intended for using a model's weights as lora hook.
505
526
'''
506
- current_hooked_patches : dict [str ,list ] = self .hooked_patches .get (lora_hook , {})
527
+ current_hooked_patches : dict [str ,list ] = self .hooked_patches .get (lora_hook . hook_ref , {})
507
528
p = set ()
508
- for key in patches :
509
- model_sd = self .model .state_dict ()
529
+ model_sd = self .model .state_dict ()
530
+ for k in patches :
531
+ offset = None
532
+ if isinstance (k , str ):
533
+ key = k
534
+ else :
535
+ offset = k [1 ]
536
+ key = k [0 ]
537
+
510
538
if key in model_sd :
511
- p .add (key )
539
+ p .add (k )
512
540
current_patches : list [tuple ] = current_hooked_patches .get (key , [])
513
541
# take difference between desired weight and existing weight to get diff
514
- current_patches .append ((strength_patch , (patches [key ]- comfy .utils .get_attr (self .model , key ),), strength_model ))
542
+ current_patches .append ((strength_patch , (patches [k ]- comfy .utils .get_attr (self .model , key ),), strength_model , offset ))
515
543
current_hooked_patches [key ] = current_patches
516
- self .hooked_patches [lora_hook ] = current_hooked_patches
544
+ self .hooked_patches [lora_hook . hook_ref ] = current_hooked_patches
517
545
# since should care about these patches too to determine if same model, reroll patches_uuid
518
546
self .patches_uuid = uuid .uuid4 ()
519
547
return list (p )
@@ -526,7 +554,7 @@ def get_combined_hooked_patches(self, lora_hooks: LoraHookGroup):
526
554
combined_patches = {}
527
555
if lora_hooks is not None :
528
556
for hook in lora_hooks .hooks :
529
- hook_patches : dict = self .hooked_patches .get (hook , {})
557
+ hook_patches : dict = self .hooked_patches .get (hook . hook_ref , {})
530
558
for key in hook_patches .keys ():
531
559
current_patches : list [tuple ] = combined_patches .get (key , [])
532
560
current_patches .extend (hook_patches [key ])
0 commit comments