
Commit 9acfb6d

feat: add clip_vision annotator, support non-image input

1 parent bff62ed
11 files changed: +172 -59 lines

README.md (+1 -1)

@@ -1,5 +1,5 @@
 ## sd-webui-controlnet
-(WIP) WebUI extension for ControlNet
+(WIP) WebUI extension for ControlNet and T2I-Adapter
 
 This extension is for AUTOMATIC1111's [Stable Diffusion web UI](https://github.com/AUTOMATIC1111/stable-diffusion-webui), allows the Web UI to add [ControlNet](https://github.com/lllyasviel/ControlNet) to the original Stable Diffusion model to generate images. The addition is on-the-fly, the merging is not required.

annotator/clip/__init__.py (new file, +23)

@@ -0,0 +1,23 @@
+from transformers import CLIPProcessor, CLIPVisionModel
+from modules import devices
+
+version = 'openai/clip-vit-large-patch14'
+clip_proc = None
+clip_vision_model = None
+
+def apply_clip(img):
+    global clip_proc, clip_vision_model
+
+    if clip_vision_model is None:
+        clip_proc = CLIPProcessor.from_pretrained(version)
+        clip_vision_model = CLIPVisionModel.from_pretrained(version)
+
+    clip_vision_model = clip_vision_model.to(devices.get_device_for("controlnet"))
+    style_for_clip = clip_proc(images=img, return_tensors="pt")['pixel_values']
+    style_feat = clip_vision_model(style_for_clip.to(devices.get_device_for("controlnet")))['last_hidden_state']
+    return style_feat
+
+def unload_clip_model():
+    global clip_proc, clip_vision_model
+    if clip_vision_model is not None:
+        clip_vision_model.cpu()
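
For reference, a minimal sketch of how the new annotator could be exercised on its own inside the webui environment (the PIL input, file name, and printed shape are illustrative assumptions, not part of this commit):

    from PIL import Image
    from annotator.clip import apply_clip, unload_clip_model

    # Any RGB image can serve as a style reference; CLIPProcessor handles resizing and normalization.
    style_image = Image.open("style_reference.png").convert("RGB")

    # Returns the last_hidden_state of the CLIP ViT-L/14 vision tower,
    # expected to be a [1, 257, 1024] tensor (1 class token + 16x16 patch tokens) for one image.
    style_feat = apply_clip(style_image)
    print(style_feat.shape)

    # Move the vision model back to the CPU when finished to free VRAM.
    unload_clip_model()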

models/color_adapter_v14.yaml (+1 -1)

@@ -1,5 +1,5 @@
 model:
-  target: tencentarc.t21_adapter
+  target: scripts.adapter.Adapter_light
   params:
     channels: [320, 640, 1280, 1280]
     nums_rb: 4

models/style_adapter_v14.yaml (+1 -1)

@@ -1,5 +1,5 @@
 model:
-  target: tencentarc.t21_adapter
+  target: scripts.adapter.StyleAdapter
   params:
     width: 1024
     context_dim: 768

models/t2iadapter_color_sd14v1.yaml (new file, +6)

@@ -0,0 +1,6 @@
+model:
+  target: scripts.adapter.Adapter_light
+  params:
+    channels: [320, 640, 1280, 1280]
+    nums_rb: 4
+    cin: 192
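
As an aside, cin is the channel count the adapter sees after flattening its 3-channel hint into latent-resolution patches; assuming the upstream T2I-Adapter convention of an 8x PixelUnshuffle on the hint image, 3 x 8 x 8 = 192. A tiny sketch of that relationship (tensor sizes are illustrative):

    import torch

    hint = torch.randn(1, 3, 512, 512)            # RGB color hint at full resolution
    unshuffled = torch.nn.PixelUnshuffle(8)(hint) # -> torch.Size([1, 192, 64, 64])
    print(unshuffled.shape[1])                    # 192, matching cin above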

models/t2iadapter_keypose_sd14v1.yaml (new file, +9)

@@ -0,0 +1,9 @@
+model:
+  target: tencentarc.t21_adapter
+  params:
+    channels: [320, 640, 1280, 1280]
+    nums_rb: 2
+    ksize: 1
+    sk: true
+    cin: 192
+    use_conv: false

models/t2iadapter_style_sd14v1.yaml (new file, +8)

@@ -0,0 +1,8 @@
+model:
+  target: scripts.adapter.StyleAdapter
+  params:
+    width: 1024
+    context_dim: 768
+    num_head: 8
+    n_layes: 3
+    num_token: 8

scripts/adapter.py (+17 -8)

@@ -2,6 +2,7 @@
 
 import torch
 import torch.nn as nn
+import importlib
 from collections import OrderedDict
 
 from omegaconf import OmegaConf
@@ -55,21 +56,28 @@ def get_node_name(name, parent_name):
     if p != parent_name:
         return False, ''
     return True, name[len(parent_name):]
+
+
+def get_obj_from_str(string, reload=False):
+    module, cls = string.rsplit(".", 1)
+    if reload:
+        module_imp = importlib.import_module(module)
+        importlib.reload(module_imp)
+    return getattr(importlib.import_module(module, package=None), cls)
 
 
 class PlugableAdapter(nn.Module):
     def __init__(self, state_dict, config_path, lowvram=False, base_model=None) -> None:
         super().__init__()
         config = OmegaConf.load(config_path)
+        model = Adapter
+        try:
+            self.target = config.model.target
+            model = get_obj_from_str(config.model.target)
+        except ImportError:
+            pass
 
-        if (config.model.params.cin == 64 * 6):
-            config.model.params.cin = 192
-            self.control_model = Adapter_light(**config.model.params)
-        elif (config.model.params.cin == 64 * 7):
-            del config.model.params.cin
-            self.control_model = StyleAdapter(**config.model.params)
-        else:
-            self.control_model = Adapter(**config.model.params)
+        self.control_model = model(**config.model.params)
         self.control_model.load_state_dict(state_dict)
         self.lowvram = lowvram
         self.control = None
@@ -312,6 +320,7 @@ def forward(self, x):
         # x shape [N, HW+1, C]
         style_embedding = self.style_embedding + torch.zeros(
            (x.shape[0], self.num_token, self.style_embedding.shape[-1]), device=x.device)
+
        x = torch.cat([x, style_embedding], dim=1)
        x = self.ln_pre(x)
        x = x.permute(1, 0, 2) # NLD -> LND
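
PlugableAdapter now picks the adapter class from the YAML target string instead of inferring it from cin. A minimal sketch of that resolution path (the config file is one added in this commit; the rest is illustrative and assumes the webui environment):

    from omegaconf import OmegaConf
    from scripts.adapter import get_obj_from_str

    config = OmegaConf.load("models/t2iadapter_style_sd14v1.yaml")

    # "scripts.adapter.StyleAdapter" -> the StyleAdapter class,
    # resolved via importlib.import_module("scripts.adapter") + getattr.
    cls = get_obj_from_str(config.model.target)

    # The params block is passed straight to the constructor,
    # so the YAML keys must match the class's __init__ signature.
    adapter = cls(**config.model.params)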

scripts/controlnet.py (+53 -34)

@@ -181,7 +181,8 @@ def __init__(self) -> None:
             "mlsd": mlsd,
             "normal_map": midas_normal,
             "openpose": openpose,
-            # "openpose_hand": openpose_hand,
+            "openpose_hand": openpose_hand,
+            "clip_vision": clip,
             "pidinet": pidinet,
             "scribble": simple_scribble,
             "fake_scribble": fake_scribble,
@@ -191,6 +192,7 @@ def __init__(self) -> None:
             "hed": unload_hed,
             "fake_scribble": unload_hed,
             "mlsd": unload_mlsd,
+            "clip": unload_clip,
             "depth": unload_midas,
             "depth_leres": unload_leres,
             "normal_map": unload_midas,
@@ -532,6 +534,38 @@ def parse_remote_call(self, p, params, idx):
 
         return (enabled, module, model, weight, image, scribble_mode, \
             resize_mode, rgbbgr_mode, lowvram, pres, pthr_a, pthr_b, guidance_start, guidance_end, guess_mode), input_image
+
+    def detectmap_proc(self, module, rgbbgr_mode, resize_mode, h, w):
+        detected_map = HWC3(detected_map)
+        if module == "normal_map" or rgbbgr_mode:
+            control = torch.from_numpy(detected_map[:, :, ::-1].copy()).float().to(devices.get_device_for("controlnet")) / 255.0
+        else:
+            control = torch.from_numpy(detected_map.copy()).float().to(devices.get_device_for("controlnet")) / 255.0
+
+        control = rearrange(control, 'h w c -> c h w')
+        detected_map = rearrange(torch.from_numpy(detected_map), 'h w c -> c h w')
+
+        if resize_mode == "Scale to Fit (Inner Fit)":
+            transform = Compose([
+                Resize(h if h<w else w, interpolation=InterpolationMode.BICUBIC),
+                CenterCrop(size=(h, w)),
+            ])
+            control = transform(control)
+            detected_map = transform(detected_map)
+        elif resize_mode == "Envelope (Outer Fit)":
+            transform = Compose([
+                Resize(h if h>w else w, interpolation=InterpolationMode.BICUBIC),
+                CenterCrop(size=(h, w))
+            ])
+            control = transform(control)
+            detected_map = transform(detected_map)
+        else:
+            control = Resize((h,w), interpolation=InterpolationMode.BICUBIC)(control)
+            detected_map = Resize((h,w), interpolation=InterpolationMode.BICUBIC)(detected_map)
+
+        # for log use
+        detected_map = rearrange(detected_map, 'c h w -> h w c').numpy().astype(np.uint8)
+        return control, detected_map
 
     def process(self, p, is_img2img=False, *args):
         """
@@ -652,43 +686,28 @@ def process(self, p, is_img2img=False, *args):
             preprocessor = self.preprocessor[module]
             h, w, bsz = p.height, p.width, p.batch_size
             if pres > 64:
-                detected_map = preprocessor(input_image, res=pres, thr_a=pthr_a, thr_b=pthr_b)
+                detected_map, is_image = preprocessor(input_image, res=pres, thr_a=pthr_a, thr_b=pthr_b)
             else:
-                detected_map = preprocessor(input_image)
-
-            detected_map = HWC3(detected_map)
-            if module == "normal_map" or rgbbgr_mode:
-                control = torch.from_numpy(detected_map[:, :, ::-1].copy()).float().to(devices.get_device_for("controlnet")) / 255.0
-            else:
-                control = torch.from_numpy(detected_map.copy()).float().to(devices.get_device_for("controlnet")) / 255.0
+                detected_map, is_image = preprocessor(input_image)
 
-            control = rearrange(control, 'h w c -> c h w')
-            detected_map = rearrange(torch.from_numpy(detected_map), 'h w c -> c h w')
-
-            if resize_mode == "Scale to Fit (Inner Fit)":
-                transform = Compose([
-                    Resize(h if h<w else w, interpolation=InterpolationMode.BICUBIC),
-                    CenterCrop(size=(h, w)),
-                ])
-                control = transform(control)
-                detected_map = transform(detected_map)
-            elif resize_mode == "Envelope (Outer Fit)":
-                transform = Compose([
-                    Resize(h if h>w else w, interpolation=InterpolationMode.BICUBIC),
-                    CenterCrop(size=(h, w))
-                ])
-                control = transform(control)
-                detected_map = transform(detected_map)
+            if is_image:
+                control, detected_map = self.detectmap_proc(detected_map, rgbbgr_mode, resize_mode, h, w)
+                detected_maps.append((detected_map, module))
             else:
-                control = Resize((h,w), interpolation=InterpolationMode.BICUBIC)(control)
-                detected_map = Resize((h,w), interpolation=InterpolationMode.BICUBIC)(detected_map)
-
-            # for log use
-            detected_map = rearrange(detected_map, 'c h w -> h w c').numpy().astype(np.uint8)
-            detected_maps.append((detected_map, module))
+                control = detected_map
 
-            # hint_cond, guess_mode, weight, guidance_stopped, stop_guidance_percent, advanced_weighting
-            forward_param = ControlParams(model_net, control, guess_mode, weight, False, guidance_start, guidance_end, None, isinstance(model_net, PlugableAdapter))
+            forward_param = ControlParams(
+                control_model=model_net,
+                hint_cond=control,
+                guess_mode=guess_mode,
+                weight=weight,
+                guidance_stopped=False,
+                start_guidance_percent=guidance_start,
+                stop_guidance_percent=guidance_end,
+                advanced_weighting=None,
+                is_adapter=isinstance(model_net, PlugableAdapter),
+                is_extra_cond=getattr(model_net, "target", "") == "scripts.adapter.StyleAdapter"
+            )
             forward_params.append(forward_param)
 
         self.latest_network = UnetHook(lowvram=hook_lowvram)
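
With this change a preprocessor is expected to return a (result, is_image) pair: image-like maps go through detectmap_proc (resize, crop, logging), while non-image outputs such as the CLIP embedding are handed to the adapter unchanged. A rough sketch of the two shapes of return value (the wrapper names and helper are illustrative, not the actual annotator code):

    # Image-producing annotator: the HxWx3 map is resized/cropped and logged.
    def canny_like(img, res=512, **kwargs):
        detected_map = run_edge_detector(img, res)  # hypothetical helper returning HxWx3 uint8
        return detected_map, True

    # Non-image annotator: the tensor bypasses detectmap_proc (control = detected_map).
    def clip_like(img, **kwargs):
        style_feat = apply_clip(img)                # [1, tokens, channels] torch tensor
        return style_feat, False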

scripts/hook.py (+20 -1)

@@ -49,7 +49,8 @@ def __init__(
         start_guidance_percent,
         stop_guidance_percent,
         advanced_weighting,
-        is_adapter
+        is_adapter,
+        is_extra_cond
     ):
         self.control_model = control_model
         self.hint_cond = hint_cond
@@ -60,6 +61,7 @@ def __init__(
         self.stop_guidance_percent = stop_guidance_percent
         self.advanced_weighting = advanced_weighting
         self.is_adapter = is_adapter
+        self.is_extra_cond = is_extra_cond
 
 
 class UnetHook(nn.Module):
@@ -108,6 +110,7 @@ def cfg_based_adder(base, x, require_autocast, is_adapter=False):
        def forward(self, x, timesteps=None, context=None, **kwargs):
            total_control = [0.0] * 13
            total_adapter = [0.0] * 4
+            total_extra_cond = torch.zeros([0, context.shape[-1]]).to(devices.get_device_for("controlnet"))
            only_mid_control = outer.only_mid_control
            require_inpaint_hijack = False
 
@@ -138,6 +141,9 @@ def forward(self, x, timesteps=None, context=None, **kwargs):
 
                if outer.lowvram:
                    param.control_model.to("cpu")
+                if param.is_extra_cond:
+                    total_extra_cond = torch.cat([total_extra_cond, control.clone().squeeze(0)]) #* param.weight
+                    continue
                if param.guess_mode:
                    if param.is_adapter:
                        # see https://github.com/Mikubill/sd-webui-controlnet/issues/269
@@ -153,6 +159,19 @@ def forward(self, x, timesteps=None, context=None, **kwargs):
                    target[idx] += item
 
            control = total_control
+            if len(total_extra_cond) > 0 and context.shape[0] % 2 == 0:
+                total_extra_cond = torch.repeat_interleave(total_extra_cond.unsqueeze(0), context.shape[0] // 2, dim=0)
+                if outer.is_vanilla_samplers:
+                    uncond, cond = context.chunk(2)
+                    cond = torch.cat([cond, total_extra_cond], dim=1)
+                    uncond = torch.cat([uncond, uncond[:, -total_extra_cond.shape[1]:, :]], dim=1)
+                    context = torch.cat([uncond, cond], dim=0)
+                else:
+                    cond, uncond = context.chunk(2)
+                    cond = torch.cat([cond, total_extra_cond], dim=1)
+                    uncond = torch.cat([uncond, uncond[:, -total_extra_cond.shape[1]:, :]], dim=1)
+                    context = torch.cat([cond, uncond], dim=0)
+
            assert timesteps is not None, ValueError(f"insufficient timestep: {timesteps}")
            hs = []
            with th.no_grad():
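
The new is_extra_cond branch appends the StyleAdapter's tokens to the text conditioning rather than to the U-Net residuals: the style tokens are concatenated onto the cond half of the context, and the uncond half is padded with its own last tokens so both halves keep the same sequence length. A small shape-level sketch under assumed sizes (batch 2 per half, 77 text tokens, 8 style tokens, context dim 768):

    import torch

    cond   = torch.randn(2, 77, 768)  # positive-prompt conditioning
    uncond = torch.randn(2, 77, 768)  # negative-prompt conditioning
    style  = torch.randn(2, 8, 768)   # total_extra_cond after repeat_interleave

    cond   = torch.cat([cond, style], dim=1)                             # [2, 85, 768]
    uncond = torch.cat([uncond, uncond[:, -style.shape[1]:, :]], dim=1)  # [2, 85, 768]

    context = torch.cat([cond, uncond], dim=0)                           # [4, 85, 768], k-diffusion ordering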
