@@ -249,7 +249,7 @@ def inject_functions(self, model: ModelPatcherAndInjector, params: InjectionPara
249
249
self .orig_forward_timestep_embed = openaimodel .forward_timestep_embed # needed to account for VanillaTemporalModule
250
250
self .orig_memory_required = model .model .memory_required # allows for "unlimited area hack" to prevent halving of conds/unconds
251
251
self .orig_groupnorm_forward = torch .nn .GroupNorm .forward # used to normalize latents to remove "flickering" of colors/brightness between frames
252
- self .orig_groupnorm_manual_cast_forward = comfy .ops .manual_cast .GroupNorm .forward_comfy_cast_weights
252
+ self .orig_groupnorm_forward_comfy_cast_weights = comfy .ops .disable_weight_init .GroupNorm .forward_comfy_cast_weights
253
253
self .orig_sampling_function = comfy .samplers .sampling_function # used to support sliding context windows in samplers
254
254
self .orig_get_area_and_mult = comfy .samplers .get_area_and_mult
255
255
if SAMPLE_FALLBACK : # for backwards compatibility, for now
@@ -267,7 +267,7 @@ def inject_functions(self, model: ModelPatcherAndInjector, params: InjectionPara
267
267
if not (info .mm_version == AnimateDiffVersion .V3 or
268
268
(info .mm_format not in [AnimateDiffFormat .HOTSHOTXL ] and info .sd_type == ModelTypeSD .SD1_5 and info .mm_version == AnimateDiffVersion .V2 and params .apply_v2_properly )):
269
269
torch .nn .GroupNorm .forward = groupnorm_mm_factory (params )
270
- comfy .ops .manual_cast .GroupNorm .forward_comfy_cast_weights = groupnorm_mm_factory (params , manual_cast = True )
270
+ comfy .ops .disable_weight_init .GroupNorm .forward_comfy_cast_weights = groupnorm_mm_factory (params , manual_cast = True )
271
271
# if mps device (Apple Silicon), disable batched conds to avoid black images with groupnorm hack
272
272
try :
273
273
if model .load_device .type == "mps" :
@@ -293,7 +293,7 @@ def restore_functions(self, model: ModelPatcherAndInjector):
293
293
model .model .memory_required = self .orig_memory_required
294
294
openaimodel .forward_timestep_embed = self .orig_forward_timestep_embed
295
295
torch .nn .GroupNorm .forward = self .orig_groupnorm_forward
296
- comfy .ops .manual_cast .GroupNorm .forward_comfy_cast_weights = self .orig_groupnorm_manual_cast_forward
296
+ comfy .ops .disable_weight_init .GroupNorm .forward_comfy_cast_weights = self .orig_groupnorm_forward_comfy_cast_weights
297
297
comfy .samplers .sampling_function = self .orig_sampling_function
298
298
comfy .samplers .get_area_and_mult = self .orig_get_area_and_mult
299
299
if SAMPLE_FALLBACK : # for backwards compatibility, for now
0 commit comments