Skip to content

Commit ad66f7c

Browse files
Add model_options to load_controlnet function.
1 parent de8e8e3 commit ad66f7c

File tree

2 files changed

+32
-26
lines changed

2 files changed

+32
-26
lines changed

comfy/controlnet.py

Lines changed: 31 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,7 @@ def forward(self, input):
335335

336336

337337
class ControlLora(ControlNet):
338-
def __init__(self, control_weights, global_average_pooling=False, device=None):
338+
def __init__(self, control_weights, global_average_pooling=False, device=None, model_options={}): #TODO? model_options
339339
ControlBase.__init__(self, device)
340340
self.control_weights = control_weights
341341
self.global_average_pooling = global_average_pooling
@@ -392,19 +392,22 @@ def get_models(self):
392392
def inference_memory_requirements(self, dtype):
393393
return comfy.utils.calculate_parameters(self.control_weights) * comfy.model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)
394394

395-
def controlnet_config(sd):
395+
def controlnet_config(sd, model_options={}):
396396
model_config = comfy.model_detection.model_config_from_unet(sd, "", True)
397397

398398
supported_inference_dtypes = model_config.supported_inference_dtypes
399399

400400
controlnet_config = model_config.unet_config
401-
unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
401+
unet_dtype = model_options.get("dtype", comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes))
402402
load_device = comfy.model_management.get_torch_device()
403403
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
404-
if manual_cast_dtype is not None:
405-
operations = comfy.ops.manual_cast
406-
else:
407-
operations = comfy.ops.disable_weight_init
404+
405+
operations = model_options.get("custom_operations", None)
406+
if operations is None:
407+
if manual_cast_dtype is not None:
408+
operations = comfy.ops.manual_cast
409+
else:
410+
operations = comfy.ops.disable_weight_init
408411

409412
offload_device = comfy.model_management.unet_offload_device()
410413
return model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device
@@ -419,9 +422,9 @@ def controlnet_load_state_dict(control_model, sd):
419422
logging.debug("unexpected controlnet keys: {}".format(unexpected))
420423
return control_model
421424

422-
def load_controlnet_mmdit(sd):
425+
def load_controlnet_mmdit(sd, model_options={}):
423426
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
424-
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd)
427+
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd, model_options=model_options)
425428
num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
426429
for k in sd:
427430
new_sd[k] = sd[k]
@@ -440,8 +443,8 @@ def load_controlnet_mmdit(sd):
440443
return control
441444

442445

443-
def load_controlnet_hunyuandit(controlnet_data):
444-
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(controlnet_data)
446+
def load_controlnet_hunyuandit(controlnet_data, model_options={}):
447+
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(controlnet_data, model_options=model_options)
445448

446449
control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=offload_device, dtype=unet_dtype)
447450
control_model = controlnet_load_state_dict(control_model, controlnet_data)
@@ -451,17 +454,17 @@ def load_controlnet_hunyuandit(controlnet_data):
451454
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds, strength_type=StrengthType.CONSTANT)
452455
return control
453456

454-
def load_controlnet_flux_xlabs_mistoline(sd, mistoline=False):
455-
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd)
457+
def load_controlnet_flux_xlabs_mistoline(sd, mistoline=False, model_options={}):
458+
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
456459
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(mistoline=mistoline, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
457460
control_model = controlnet_load_state_dict(control_model, sd)
458461
extra_conds = ['y', 'guidance']
459462
control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
460463
return control
461464

462-
def load_controlnet_flux_instantx(sd):
465+
def load_controlnet_flux_instantx(sd, model_options={}):
463466
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
464-
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd)
467+
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd, model_options=model_options)
465468
for k in sd:
466469
new_sd[k] = sd[k]
467470

@@ -487,13 +490,13 @@ def convert_mistoline(sd):
487490
return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})
488491

489492

490-
def load_controlnet(ckpt_path, model=None):
493+
def load_controlnet(ckpt_path, model=None, model_options={}):
491494
controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
492495
if 'after_proj_list.18.bias' in controlnet_data.keys(): #Hunyuan DiT
493-
return load_controlnet_hunyuandit(controlnet_data)
496+
return load_controlnet_hunyuandit(controlnet_data, model_options=model_options)
494497

495498
if "lora_controlnet" in controlnet_data:
496-
return ControlLora(controlnet_data)
499+
return ControlLora(controlnet_data, model_options=model_options)
497500

498501
controlnet_config = None
499502
supported_inference_dtypes = None
@@ -550,13 +553,13 @@ def load_controlnet(ckpt_path, model=None):
550553
controlnet_data = new_sd
551554
elif "controlnet_blocks.0.weight" in controlnet_data:
552555
if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data:
553-
return load_controlnet_flux_xlabs_mistoline(controlnet_data)
556+
return load_controlnet_flux_xlabs_mistoline(controlnet_data, model_options=model_options)
554557
elif "pos_embed_input.proj.weight" in controlnet_data:
555-
return load_controlnet_mmdit(controlnet_data) #SD3 diffusers controlnet
558+
return load_controlnet_mmdit(controlnet_data, model_options=model_options) #SD3 diffusers controlnet
556559
elif "controlnet_x_embedder.weight" in controlnet_data:
557-
return load_controlnet_flux_instantx(controlnet_data)
560+
return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
558561
elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux
559-
return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True)
562+
return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options)
560563

561564
pth_key = 'control_model.zero_convs.0.0.weight'
562565
pth = False
@@ -568,7 +571,7 @@ def load_controlnet(ckpt_path, model=None):
568571
elif key in controlnet_data:
569572
prefix = ""
570573
else:
571-
net = load_t2i_adapter(controlnet_data)
574+
net = load_t2i_adapter(controlnet_data, model_options=model_options)
572575
if net is None:
573576
logging.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path))
574577
return net
@@ -587,7 +590,10 @@ def load_controlnet(ckpt_path, model=None):
587590
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
588591
if manual_cast_dtype is not None:
589592
controlnet_config["operations"] = comfy.ops.manual_cast
590-
controlnet_config["dtype"] = unet_dtype
593+
if "custom_operations" in model_options:
594+
controlnet_config["operations"] = model_options["custom_operations"]
595+
if "dtype" in model_options:
596+
controlnet_config["dtype"] = model_options["dtype"]
591597
controlnet_config["device"] = comfy.model_management.unet_offload_device()
592598
controlnet_config.pop("out_channels")
593599
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
@@ -685,7 +691,7 @@ def copy(self):
685691
self.copy_to(c)
686692
return c
687693

688-
def load_t2i_adapter(t2i_data):
694+
def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
689695
compression_ratio = 8
690696
upscale_algorithm = 'nearest-exact'
691697

comfy/sd.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -645,7 +645,7 @@ def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffuse
645645

646646
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
647647
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
648-
model_config.custom_operations = model_options.get("custom_operations", None)
648+
model_config.custom_operations = model_options.get("custom_operations", model_config.custom_operations)
649649
model = model_config.get_model(new_sd, "")
650650
model = model.to(offload_device)
651651
model.load_model_weights(new_sd, "")

0 commit comments

Comments
 (0)