Commit ae9bc7b

Merge PR #382 from Kosinkadink/develop - lowvram fix
Made MotionModelPatcher handle pe's properly with lowvram
2 parents: 19cfb6b + 565261a

3 files changed, +23 -8 lines changed

animatediff/model_injection.py

Lines changed: 19 additions & 4 deletions
@@ -702,11 +702,26 @@ def __init__(self, *args, **kwargs):
         self.was_within_range = False
         self.prev_sub_idxs = None
         self.prev_batched_number = None
+
+    def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, *args, **kwargs):
+        patched_model = super().patch_model_lowvram(device_to, lowvram_model_memory, force_patch_weights, *args, **kwargs)
+
+        # figure out the tensors (likely pe's) that should be cast to device besides just the named_modules
+        remaining_tensors = list(self.model.state_dict().keys())
+        named_modules = []
+        for n, _ in self.model.named_modules():
+            named_modules.append(n)
+            named_modules.append(f"{n}.weight")
+            named_modules.append(f"{n}.bias")
+        for name in named_modules:
+            if name in remaining_tensors:
+                remaining_tensors.remove(name)
+
+        for key in remaining_tensors:
+            self.patch_weight_to_device(key, device_to)
+            if device_to is not None:
+                comfy.utils.set_attr(self.model, key, comfy.utils.get_attr(self.model, key).to(device_to))
 
-    def patch_model(self, *args, **kwargs):
-        # patch as normal; used to need to do prepare_weights call to work with lowvram, but no longer needed
-        # will consider removing this override at some point since it does nothing at the moment
-        patched_model = super().patch_model(*args, **kwargs)
         return patched_model
 
     def pre_run(self, model: ModelPatcherAndInjector):
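
For context on why the override is needed: ComfyUI's lowvram path (patch_model_lowvram) only moves named modules to the load device, so bare tensors that do not belong to any submodule, such as the motion model's positional encodings ("pe"), can be left behind on the offload device. The sketch below illustrates the lookup the new override performs, using plain PyTorch; ToyMotionModule and the get_attr/set_attr helpers are hypothetical stand-ins for the repository's actual classes and the comfy.utils helpers, and the patch_weight_to_device call is omitted.

import torch

class ToyMotionModule(torch.nn.Module):
    # hypothetical stand-in for a motion model block
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)  # a regular named module; lowvram patching already covers this
        # a bare parameter not wrapped in a submodule, analogous to AnimateDiff's pe tensors
        self.pe = torch.nn.Parameter(torch.zeros(1, 8, 4), requires_grad=False)

def get_attr(obj, key: str):
    # walk a dotted state_dict key (e.g. "block.pe") down to the tensor
    for part in key.split("."):
        obj = getattr(obj, part)
    return obj

def set_attr(obj, key: str, value: torch.Tensor):
    # replace the attribute at a dotted key, mirroring what comfy.utils.set_attr is used for here
    parts = key.split(".")
    for part in parts[:-1]:
        obj = getattr(obj, part)
    setattr(obj, parts[-1], torch.nn.Parameter(value, requires_grad=False))

def cast_non_module_tensors(model: torch.nn.Module, device_to: torch.device):
    # collect state_dict keys, then strip out everything covered by a named module
    remaining = list(model.state_dict().keys())
    covered = []
    for name, _ in model.named_modules():
        covered.extend([name, f"{name}.weight", f"{name}.bias"])
    for name in covered:
        if name in remaining:
            remaining.remove(name)
    # whatever is left (here just "pe") is cast to the target device explicitly
    for key in remaining:
        set_attr(model, key, get_attr(model, key).to(device_to))

model = ToyMotionModule()
cast_non_module_tensors(model, torch.device("cpu"))
print(model.pe.device)  # pe now lives on the requested device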

animatediff/sampling.py

Lines changed: 3 additions & 3 deletions
@@ -249,7 +249,7 @@ def inject_functions(self, model: ModelPatcherAndInjector, params: InjectionPara
         self.orig_forward_timestep_embed = openaimodel.forward_timestep_embed # needed to account for VanillaTemporalModule
         self.orig_memory_required = model.model.memory_required # allows for "unlimited area hack" to prevent halving of conds/unconds
         self.orig_groupnorm_forward = torch.nn.GroupNorm.forward # used to normalize latents to remove "flickering" of colors/brightness between frames
-        self.orig_groupnorm_manual_cast_forward = comfy.ops.manual_cast.GroupNorm.forward_comfy_cast_weights
+        self.orig_groupnorm_forward_comfy_cast_weights = comfy.ops.disable_weight_init.GroupNorm.forward_comfy_cast_weights
         self.orig_sampling_function = comfy.samplers.sampling_function # used to support sliding context windows in samplers
         self.orig_get_area_and_mult = comfy.samplers.get_area_and_mult
         if SAMPLE_FALLBACK: # for backwards compatibility, for now
@@ -267,7 +267,7 @@ def inject_functions(self, model: ModelPatcherAndInjector, params: InjectionPara
         if not (info.mm_version == AnimateDiffVersion.V3 or
                 (info.mm_format not in [AnimateDiffFormat.HOTSHOTXL] and info.sd_type == ModelTypeSD.SD1_5 and info.mm_version == AnimateDiffVersion.V2 and params.apply_v2_properly)):
             torch.nn.GroupNorm.forward = groupnorm_mm_factory(params)
-            comfy.ops.manual_cast.GroupNorm.forward_comfy_cast_weights = groupnorm_mm_factory(params, manual_cast=True)
+            comfy.ops.disable_weight_init.GroupNorm.forward_comfy_cast_weights = groupnorm_mm_factory(params, manual_cast=True)
         # if mps device (Apple Silicon), disable batched conds to avoid black images with groupnorm hack
         try:
             if model.load_device.type == "mps":
@@ -293,7 +293,7 @@ def restore_functions(self, model: ModelPatcherAndInjector):
         model.model.memory_required = self.orig_memory_required
         openaimodel.forward_timestep_embed = self.orig_forward_timestep_embed
         torch.nn.GroupNorm.forward = self.orig_groupnorm_forward
-        comfy.ops.manual_cast.GroupNorm.forward_comfy_cast_weights = self.orig_groupnorm_manual_cast_forward
+        comfy.ops.disable_weight_init.GroupNorm.forward_comfy_cast_weights = self.orig_groupnorm_forward_comfy_cast_weights
         comfy.samplers.sampling_function = self.orig_sampling_function
         comfy.samplers.get_area_and_mult = self.orig_get_area_and_mult
         if SAMPLE_FALLBACK: # for backwards compatibility, for now
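
The sampling.py changes retarget the GroupNorm hack from comfy.ops.manual_cast to comfy.ops.disable_weight_init (renaming the saved reference accordingly), but the underlying pattern is unchanged: save the original forward, swap in a wrapper produced by a factory, and restore it afterwards. A minimal sketch of that save/patch/restore pattern with plain torch.nn.GroupNorm.forward follows; groupnorm_wrapper_factory is a hypothetical stand-in for groupnorm_mm_factory and simply delegates to the original forward.

import torch

def groupnorm_wrapper_factory():
    # stand-in for groupnorm_mm_factory: the real wrapper reshapes batched frames so
    # GroupNorm statistics are shared across a context window; this stub only delegates
    orig_forward = torch.nn.GroupNorm.forward

    def wrapped_forward(self, input: torch.Tensor) -> torch.Tensor:
        return orig_forward(self, input)

    return wrapped_forward

# inject: remember the original forward before swapping in the wrapper,
# mirroring what inject_functions does with self.orig_groupnorm_forward
orig_groupnorm_forward = torch.nn.GroupNorm.forward
torch.nn.GroupNorm.forward = groupnorm_wrapper_factory()

try:
    gn = torch.nn.GroupNorm(2, 4)
    out = gn(torch.randn(1, 4, 8, 8))  # routed through the wrapper
finally:
    # restore: always undo the global patch, as restore_functions does
    torch.nn.GroupNorm.forward = orig_groupnorm_forward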

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 [project]
 name = "comfyui-animatediff-evolved"
 description = "Improved AnimateDiff integration for ComfyUI."
-version = "1.0.0"
+version = "1.0.1"
 license = "LICENSE"
 dependencies = []
