# embedding_editor.py

import os

from modules.call_queue import wrap_gradio_gpu_call
from modules import scripts, script_callbacks
from modules import shared, devices, sd_hijack, processing, sd_models, images, ui
from modules.shared import opts, cmd_opts, restricted_opts
from modules.ui import create_output_panel, setup_progressbar, create_refresh_button
from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, \
    StableDiffusionProcessingImg2Img, process_images
from modules.ui import plaintext_to_html
from modules.textual_inversion.textual_inversion import save_embedding

import gradio as gr
import gradio.routes
import gradio.utils

import torch

# ISSUES
# the distribution shouldn't be fetched until the first embedding is opened, and can probably be converted into a numpy array (a vectorised sketch follows determine_embedding_distribution below)
# most functions need to verify that an embedding is selected
# vector numbers aren't verified (might be better as a slider)
# weight slider values are lost when changing the vector number
# remove unused imports
#
# TODO
# add tagged positions on sliders from user-supplied words (and unique symbols & colours)
# add a word->substrings printout for use with the above, for words which map to multiple embeddings (e.g. "computer" = "compu" + "ter")
# add the ability to create embeddings which are a mix of other embeddings (with ratios), e.g. 0.5 * skunk + 0.5 * puppy is a valid embedding (a hedged sketch follows this list)
# add the ability to shift all weights towards another embedding with a master slider
# add a strength slider (multiply all weights; sketched after apply_slider_weights below)
# print out the closest word(s) in the original embeddings list to the current embedding, with torch.abs(embedding1.vec - embedding2.vec).mean() or maybe sum (sketched near the end of this file)
# also maybe print a mouseover, or have an expandable per-weight slider, for the closest embedding(s) for that weight value
# maybe allow per-weight notes, and possibly a way to save them per embedding vector
# add an option to vary individual weights one at a time and generate outputs, potentially also combinations of weights; potentially use a scoring system to determine the size of the change (maybe latents or CLIP interrogator)
# add an option to 'move' around the current embedding position and generate outputs (a 768-dimensional vector spiral)?
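
# A minimal sketch of the embedding-mixing TODO above, assuming each entry in
# sd_hijack.model_hijack.embedding_db.word_embeddings exposes a .vec tensor as the
# rest of this file does. The function name and the bare returned tensor are
# illustrative only; turning the result into a loadable embedding would go through
# save_embedding() as in save_embedding_weights() further down.
def mix_embeddings(sources, vector_num=0):
    mixed = None
    for name, ratio in sources:  # e.g. [("skunk", 0.5), ("puppy", 0.5)]
        vec = sd_hijack.model_hijack.embedding_db.word_embeddings[name].vec[int(vector_num)]
        mixed = ratio * vec if mixed is None else mixed + ratio * vec
    return mixed
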
embedding_editor_weight_visual_scalar = 1

def determine_embedding_distribution():
    cond_model = shared.sd_model.cond_stage_model
    embedding_layer = cond_model.wrapped.transformer.text_model.embeddings

    # fix for medvram/lowvram - can't figure out how to detect the device of the model in torch, so guess from the web ui options
    device = devices.device
    if cmd_opts.medvram or cmd_opts.lowvram:
        device = torch.device("cpu")

    for i in range(49405):  # guessing that's the range of CLIP tokens, given that 49406 and 49407 are special tokens presumably appended to the end
        embedding = embedding_layer.token_embedding.wrapped(torch.LongTensor([i]).to(device)).squeeze(0)
        if i == 0:
            distribution_floor = embedding
            distribution_ceiling = embedding
        else:
            distribution_floor = torch.minimum(distribution_floor, embedding)
            distribution_ceiling = torch.maximum(distribution_ceiling, embedding)

    # a hack, but I don't know how else to get these values into gradio event functions, short of maybe caching them in an invisible gradio HTML element
    global embedding_editor_distribution_floor, embedding_editor_distribution_ceiling
    embedding_editor_distribution_floor = distribution_floor
    embedding_editor_distribution_ceiling = distribution_ceiling
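
# A hedged sketch of the vectorised variant suggested in the ISSUES notes above:
# read the whole embedding matrix out of the wrapped module in one go and reduce
# along the token axis, instead of 49k single-token forward passes. Assumes the
# wrapped module is a torch.nn.Embedding exposing a .weight tensor, which may not
# hold for every webui version, so the loop above remains the reference.
def determine_embedding_distribution_vectorised():
    global embedding_editor_distribution_floor, embedding_editor_distribution_ceiling

    embedding_layer = shared.sd_model.cond_stage_model.wrapped.transformer.text_model.embeddings

    with torch.no_grad():
        weight = embedding_layer.token_embedding.wrapped.weight[:49405]  # drop the two trailing special tokens, matching the loop above
        embedding_editor_distribution_floor = weight.min(dim=0).values
        embedding_editor_distribution_ceiling = weight.max(dim=0).values
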
def build_slider(index, default, weight_sliders):
    floor = embedding_editor_distribution_floor[index].item() * embedding_editor_weight_visual_scalar
    ceil = embedding_editor_distribution_ceiling[index].item() * embedding_editor_weight_visual_scalar

    slider = gr.Slider(minimum=floor, maximum=ceil, step="any", label=f"w{index}", value=default, interactive=True, elem_id=f'embedding_editor_weight_slider_{index}')
    weight_sliders.append(slider)

def on_ui_tabs():
    determine_embedding_distribution()
    weight_sliders = []

    with gr.Blocks(analytics_enabled=False) as embedding_editor_interface:
        with gr.Row().style(equal_height=False):
            with gr.Column(variant='panel', scale=1.5):
                with gr.Column():
                    with gr.Row():
                        embedding_name = gr.Dropdown(label='Embedding', elem_id="edit_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()), interactive=True)
                        vector_num = gr.Number(label='Vector', value=0, step=1, interactive=True)
                        refresh_embeddings_button = gr.Button(value="Refresh Embeddings", variant='secondary')
                        save_embedding_button = gr.Button(value="Save Embedding", variant='primary')

                    instructions = gr.HTML("""
                    <p>Enter words and color hexes to mark weights on the sliders for guidance. Hint: use the txt2img prompt token counter or <a style="font-weight: bold;" href="https://github.com/AUTOMATIC1111/stable-diffusion-webui-tokenizer">webui-tokenizer</a> to see which words are constructed from multiple sub-words, e.g. 'computer' doesn't exist in stable diffusion's CLIP dictionary, so 'compu' and 'ter' are used instead (1 word but 2 embedding vectors). Currently buggy and needs a moment to process before pressing the button; if it doesn't work after a moment, try adding a random space to refresh it.</p>
                    """)

                    with gr.Row():
                        guidance_embeddings = gr.Textbox(value="apple:#FF0000, banana:#FECE26, strawberry:#FF00FF", placeholder="symbol:color-hex, symbol:color-hex, ...", show_label=False, interactive=True)
                        guidance_update_button = gr.Button(value='\U0001f504', elem_id='embedding_editor_refresh_guidance')
                        guidance_hidden_cache = gr.HTML(value="", visible=False)

                with gr.Column(elem_id='embedding_editor_weight_sliders_container'):
                    for i in range(0, 128):
                        with gr.Row():
                            build_slider(i*6+0, 0, weight_sliders)
                            build_slider(i*6+1, 0, weight_sliders)
                            build_slider(i*6+2, 0, weight_sliders)
                            build_slider(i*6+3, 0, weight_sliders)
                            build_slider(i*6+4, 0, weight_sliders)
                            build_slider(i*6+5, 0, weight_sliders)

            with gr.Column(scale=1):
                gallery = gr.Gallery(label='Output', show_label=False, elem_id="embedding_editor_gallery").style(grid=4)
                prompt = gr.Textbox(label="Prompt", elem_id="embedding_editor_prompt", show_label=False, lines=2, placeholder="e.g. A portrait photo of embedding_name")
                batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1)
                steps = gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=20)
                cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0)
                seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1)

                with gr.Row():
                    generate_preview = gr.Button(value="Generate Preview", variant='primary')

                generation_info = gr.HTML()
                html_info = gr.HTML()

        preview_args = dict(
            fn=wrap_gradio_gpu_call(generate_embedding_preview),
            #_js="submit",
            inputs=[
                embedding_name,
                vector_num,
                prompt,
                steps,
                cfg_scale,
                seed,
                batch_count,
            ] + weight_sliders,
            outputs=[
                gallery,
                generation_info,
                html_info,
            ],
            show_progress=False,
        )

        generate_preview.click(**preview_args)

        selection_args = dict(
            fn=select_embedding,
            inputs=[
                embedding_name,
                vector_num,
            ],
            outputs=weight_sliders,
        )

        embedding_name.change(**selection_args)
        vector_num.change(**selection_args)

        def refresh_embeddings():
            sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
            args = {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}

            for k, v in args.items():
                setattr(embedding_name, k, v)

            return gr.update(**args)

        refresh_embeddings_button.click(
            fn=refresh_embeddings,
            inputs=[],
            outputs=[embedding_name],
        )

        save_embedding_button.click(
            fn=save_embedding_weights,
            inputs=[
                embedding_name,
                vector_num,
            ] + weight_sliders,
            outputs=[],
        )

        guidance_embeddings.change(
            fn=update_guidance_embeddings,
            inputs=[guidance_embeddings],
            outputs=[guidance_hidden_cache],
        )

        guidance_update_button.click(
            fn=None,
            _js="embedding_editor_update_guidance",
            inputs=[guidance_hidden_cache],
            outputs=[],
        )

        guidance_hidden_cache.value = update_guidance_embeddings(guidance_embeddings.value)

    return [(embedding_editor_interface, "Embedding Editor", "embedding_editor_interface")]

def select_embedding(embedding_name, vector_num):
    embedding = sd_hijack.model_hijack.embedding_db.word_embeddings[embedding_name]
    vec = embedding.vec[int(vector_num)]

    weights = []
    for i in range(0, 768):
        weights.append(vec[i].item() * embedding_editor_weight_visual_scalar)

    return weights

def apply_slider_weights(embedding_name, vector_num, weights):
    embedding = sd_hijack.model_hijack.embedding_db.word_embeddings[embedding_name]
    vec = embedding.vec[int(vector_num)]

    old_weights = []
    for i in range(0, 768):
        old_weights.append(vec[i].item())
        vec[i] = weights[i] / embedding_editor_weight_visual_scalar

    return old_weights
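
# A minimal sketch for the "strength slider" TODO at the top of this file: scale
# all slider values by one scalar before handing them to apply_slider_weights()
# above. The function name is illustrative; a gr.Slider and a .change() handler
# would still need to be wired up in on_ui_tabs() for this to be usable.
def scale_slider_weights(weights, strength):
    return [w * strength for w in weights]
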
def generate_embedding_preview(embedding_name, vector_num, prompt: str, steps: int, cfg_scale: float, seed: int, batch_count: int, *weights):
    old_weights = apply_slider_weights(embedding_name, vector_num, weights)

    p = StableDiffusionProcessingTxt2Img(
        sd_model=shared.sd_model,
        outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
        outpath_grids=opts.outdir_grids or opts.outdir_txt2img_grids,
        prompt=prompt,
        seed=seed,
        steps=steps,
        cfg_scale=cfg_scale,
        n_iter=batch_count,
    )

    if cmd_opts.enable_console_prompts:
        print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)

    processed = process_images(p)
    p.close()
    shared.total_tqdm.clear()

    generation_info_js = processed.js()
    if opts.samples_log_stdout:
        print(generation_info_js)

    apply_slider_weights(embedding_name, vector_num, old_weights)  # restore the embedding's original weights

    return processed.images, generation_info_js, plaintext_to_html(processed.info)

def save_embedding_weights(embedding_name, vector_num, *weights):
    apply_slider_weights(embedding_name, vector_num, weights)

    embedding = sd_hijack.model_hijack.embedding_db.word_embeddings[embedding_name]
    checkpoint = sd_models.select_checkpoint()
    filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
    save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True)

def update_guidance_embeddings(text):
    try:
        cond_model = shared.sd_model.cond_stage_model
        embedding_layer = cond_model.wrapped.transformer.text_model.embeddings

        pairs = [x.strip() for x in text.split(',')]
        col_weights = {}

        for pair in pairs:
            word, col = pair.split(":")
            ids = cond_model.tokenizer(word, max_length=77, return_tensors="pt", add_special_tokens=False)["input_ids"]
            embedding = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)[0]

            weights = []
            for i in range(0, 768):
                weight = embedding[i].item()
                floor = embedding_editor_distribution_floor[i].item()
                ceiling = embedding_editor_distribution_ceiling[i].item()
                weight = (weight - floor) / (ceiling - floor)  # normalise into [0, 1] for use as a guidance marker along the slider
                weights.append(weight)

            col_weights[col] = weights

        return col_weights
    except Exception:
        # malformed input while the user is still typing; return an empty mapping so the cache stays valid
        return {}
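
# A hedged sketch of the "closest word(s)" TODO at the top of this file: rank every
# CLIP token by the mean absolute difference from the selected embedding vector,
# per the torch.abs(a - b).mean() idea in that note. Assumes the wrapped module is
# a torch.nn.Embedding exposing .weight and that the tokenizer provides
# get_vocab(); not wired into the UI, and the function name is illustrative.
def closest_tokens(embedding_name, vector_num, top_n=5):
    cond_model = shared.sd_model.cond_stage_model
    embedding_layer = cond_model.wrapped.transformer.text_model.embeddings

    vec = sd_hijack.model_hijack.embedding_db.word_embeddings[embedding_name].vec[int(vector_num)]

    with torch.no_grad():
        weight = embedding_layer.token_embedding.wrapped.weight[:49405]  # skip the trailing special tokens, as above
        distances = torch.abs(weight - vec.to(weight.device)).mean(dim=1)
        ranked = torch.argsort(distances)[:top_n]

    id_to_token = {v: k for k, v in cond_model.tokenizer.get_vocab().items()}
    return [(id_to_token[i.item()], distances[i].item()) for i in ranked]
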
script_callbacks.on_ui_tabs(on_ui_tabs)