@@ -335,7 +335,7 @@ def forward(self, input):
335
335
336
336
337
337
class ControlLora (ControlNet ):
338
- def __init__ (self , control_weights , global_average_pooling = False , device = None ):
338
+ def __init__ (self , control_weights , global_average_pooling = False , device = None , model_options = {}): #TODO? model_options
339
339
ControlBase .__init__ (self , device )
340
340
self .control_weights = control_weights
341
341
self .global_average_pooling = global_average_pooling
@@ -392,19 +392,22 @@ def get_models(self):
392
392
def inference_memory_requirements (self , dtype ):
393
393
return comfy .utils .calculate_parameters (self .control_weights ) * comfy .model_management .dtype_size (dtype ) + ControlBase .inference_memory_requirements (self , dtype )
394
394
395
- def controlnet_config (sd ):
395
+ def controlnet_config (sd , model_options = {} ):
396
396
model_config = comfy .model_detection .model_config_from_unet (sd , "" , True )
397
397
398
398
supported_inference_dtypes = model_config .supported_inference_dtypes
399
399
400
400
controlnet_config = model_config .unet_config
401
- unet_dtype = comfy .model_management .unet_dtype (supported_dtypes = supported_inference_dtypes )
401
+ unet_dtype = model_options . get ( "dtype" , comfy .model_management .unet_dtype (supported_dtypes = supported_inference_dtypes ) )
402
402
load_device = comfy .model_management .get_torch_device ()
403
403
manual_cast_dtype = comfy .model_management .unet_manual_cast (unet_dtype , load_device )
404
- if manual_cast_dtype is not None :
405
- operations = comfy .ops .manual_cast
406
- else :
407
- operations = comfy .ops .disable_weight_init
404
+
405
+ operations = model_options .get ("custom_operations" , None )
406
+ if operations is None :
407
+ if manual_cast_dtype is not None :
408
+ operations = comfy .ops .manual_cast
409
+ else :
410
+ operations = comfy .ops .disable_weight_init
408
411
409
412
offload_device = comfy .model_management .unet_offload_device ()
410
413
return model_config , operations , load_device , unet_dtype , manual_cast_dtype , offload_device
@@ -419,9 +422,9 @@ def controlnet_load_state_dict(control_model, sd):
419
422
logging .debug ("unexpected controlnet keys: {}" .format (unexpected ))
420
423
return control_model
421
424
422
- def load_controlnet_mmdit (sd ):
425
+ def load_controlnet_mmdit (sd , model_options = {} ):
423
426
new_sd = comfy .model_detection .convert_diffusers_mmdit (sd , "" )
424
- model_config , operations , load_device , unet_dtype , manual_cast_dtype , offload_device = controlnet_config (new_sd )
427
+ model_config , operations , load_device , unet_dtype , manual_cast_dtype , offload_device = controlnet_config (new_sd , model_options = model_options )
425
428
num_blocks = comfy .model_detection .count_blocks (new_sd , 'joint_blocks.{}.' )
426
429
for k in sd :
427
430
new_sd [k ] = sd [k ]
@@ -440,8 +443,8 @@ def load_controlnet_mmdit(sd):
440
443
return control
441
444
442
445
443
- def load_controlnet_hunyuandit (controlnet_data ):
444
- model_config , operations , load_device , unet_dtype , manual_cast_dtype , offload_device = controlnet_config (controlnet_data )
446
+ def load_controlnet_hunyuandit (controlnet_data , model_options = {} ):
447
+ model_config , operations , load_device , unet_dtype , manual_cast_dtype , offload_device = controlnet_config (controlnet_data , model_options = model_options )
445
448
446
449
control_model = comfy .ldm .hydit .controlnet .HunYuanControlNet (operations = operations , device = offload_device , dtype = unet_dtype )
447
450
control_model = controlnet_load_state_dict (control_model , controlnet_data )
@@ -451,17 +454,17 @@ def load_controlnet_hunyuandit(controlnet_data):
451
454
control = ControlNet (control_model , compression_ratio = 1 , latent_format = latent_format , load_device = load_device , manual_cast_dtype = manual_cast_dtype , extra_conds = extra_conds , strength_type = StrengthType .CONSTANT )
452
455
return control
453
456
454
- def load_controlnet_flux_xlabs_mistoline (sd , mistoline = False ):
455
- model_config , operations , load_device , unet_dtype , manual_cast_dtype , offload_device = controlnet_config (sd )
457
+ def load_controlnet_flux_xlabs_mistoline (sd , mistoline = False , model_options = {} ):
458
+ model_config , operations , load_device , unet_dtype , manual_cast_dtype , offload_device = controlnet_config (sd , model_options = model_options )
456
459
control_model = comfy .ldm .flux .controlnet .ControlNetFlux (mistoline = mistoline , operations = operations , device = offload_device , dtype = unet_dtype , ** model_config .unet_config )
457
460
control_model = controlnet_load_state_dict (control_model , sd )
458
461
extra_conds = ['y' , 'guidance' ]
459
462
control = ControlNet (control_model , load_device = load_device , manual_cast_dtype = manual_cast_dtype , extra_conds = extra_conds )
460
463
return control
461
464
462
- def load_controlnet_flux_instantx (sd ):
465
+ def load_controlnet_flux_instantx (sd , model_options = {} ):
463
466
new_sd = comfy .model_detection .convert_diffusers_mmdit (sd , "" )
464
- model_config , operations , load_device , unet_dtype , manual_cast_dtype , offload_device = controlnet_config (new_sd )
467
+ model_config , operations , load_device , unet_dtype , manual_cast_dtype , offload_device = controlnet_config (new_sd , model_options = model_options )
465
468
for k in sd :
466
469
new_sd [k ] = sd [k ]
467
470
@@ -487,13 +490,13 @@ def convert_mistoline(sd):
487
490
return comfy .utils .state_dict_prefix_replace (sd , {"single_controlnet_blocks." : "controlnet_single_blocks." })
488
491
489
492
490
- def load_controlnet (ckpt_path , model = None ):
493
+ def load_controlnet (ckpt_path , model = None , model_options = {} ):
491
494
controlnet_data = comfy .utils .load_torch_file (ckpt_path , safe_load = True )
492
495
if 'after_proj_list.18.bias' in controlnet_data .keys (): #Hunyuan DiT
493
- return load_controlnet_hunyuandit (controlnet_data )
496
+ return load_controlnet_hunyuandit (controlnet_data , model_options = model_options )
494
497
495
498
if "lora_controlnet" in controlnet_data :
496
- return ControlLora (controlnet_data )
499
+ return ControlLora (controlnet_data , model_options = model_options )
497
500
498
501
controlnet_config = None
499
502
supported_inference_dtypes = None
@@ -550,13 +553,13 @@ def load_controlnet(ckpt_path, model=None):
550
553
controlnet_data = new_sd
551
554
elif "controlnet_blocks.0.weight" in controlnet_data :
552
555
if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data :
553
- return load_controlnet_flux_xlabs_mistoline (controlnet_data )
556
+ return load_controlnet_flux_xlabs_mistoline (controlnet_data , model_options = model_options )
554
557
elif "pos_embed_input.proj.weight" in controlnet_data :
555
- return load_controlnet_mmdit (controlnet_data ) #SD3 diffusers controlnet
558
+ return load_controlnet_mmdit (controlnet_data , model_options = model_options ) #SD3 diffusers controlnet
556
559
elif "controlnet_x_embedder.weight" in controlnet_data :
557
- return load_controlnet_flux_instantx (controlnet_data )
560
+ return load_controlnet_flux_instantx (controlnet_data , model_options = model_options )
558
561
elif "controlnet_blocks.0.linear.weight" in controlnet_data : #mistoline flux
559
- return load_controlnet_flux_xlabs_mistoline (convert_mistoline (controlnet_data ), mistoline = True )
562
+ return load_controlnet_flux_xlabs_mistoline (convert_mistoline (controlnet_data ), mistoline = True , model_options = model_options )
560
563
561
564
pth_key = 'control_model.zero_convs.0.0.weight'
562
565
pth = False
@@ -568,7 +571,7 @@ def load_controlnet(ckpt_path, model=None):
568
571
elif key in controlnet_data :
569
572
prefix = ""
570
573
else :
571
- net = load_t2i_adapter (controlnet_data )
574
+ net = load_t2i_adapter (controlnet_data , model_options = model_options )
572
575
if net is None :
573
576
logging .error ("error checkpoint does not contain controlnet or t2i adapter data {}" .format (ckpt_path ))
574
577
return net
@@ -587,7 +590,10 @@ def load_controlnet(ckpt_path, model=None):
587
590
manual_cast_dtype = comfy .model_management .unet_manual_cast (unet_dtype , load_device )
588
591
if manual_cast_dtype is not None :
589
592
controlnet_config ["operations" ] = comfy .ops .manual_cast
590
- controlnet_config ["dtype" ] = unet_dtype
593
+ if "custom_operations" in model_options :
594
+ controlnet_config ["operations" ] = model_options ["custom_operations" ]
595
+ if "dtype" in model_options :
596
+ controlnet_config ["dtype" ] = model_options ["dtype" ]
591
597
controlnet_config ["device" ] = comfy .model_management .unet_offload_device ()
592
598
controlnet_config .pop ("out_channels" )
593
599
controlnet_config ["hint_channels" ] = controlnet_data ["{}input_hint_block.0.weight" .format (prefix )].shape [1 ]
@@ -685,7 +691,7 @@ def copy(self):
685
691
self .copy_to (c )
686
692
return c
687
693
688
- def load_t2i_adapter (t2i_data ):
694
+ def load_t2i_adapter (t2i_data , model_options = {}): #TODO: model_options
689
695
compression_ratio = 8
690
696
upscale_algorithm = 'nearest-exact'
691
697
0 commit comments