from PIL import Image
from safetensors import safe_open
from tqdm import tqdm
+import re

import sd3_impls
from other_impls import SD3Tokenizer, SDClipModel, SDXLClipG, T5XXLModel
@@ -61,7 +62,9 @@ def load_into(ckpt, model, prefix, device, dtype=None, remap=None):
                obj.requires_grad_(False)
                # print(f"K: {model_key}, O: {obj.shape} T: {tensor.shape}")
                if obj.shape != tensor.shape:
-                    print(f"W: shape mismatch for key {model_key}, {obj.shape} != {tensor.shape}")
+                    print(
+                        f"W: shape mismatch for key {model_key}, {obj.shape} != {tensor.shape}"
+                    )
                obj.set_(tensor)
            except Exception as e:
                print(f"Failed to load key '{key}' in safetensors file: {e}")
@@ -148,6 +151,11 @@ class SD3:
    def __init__(
        self, model, shift, control_model_file=None, verbose=False, device="cpu"
    ):
+
+        # NOTE 8B ControlNets were trained with a slightly different forward pass and conditioning,
+        # so this is a flag to enable that logic.
+        self.using_8b_controlnet = False
+
        with safe_open(model, framework="pt", device="cpu") as f:
            control_model_ckpt = None
            if control_model_file is not None:
@@ -165,9 +173,6 @@ def __init__(
            ).eval()
            load_into(f, self.model, "model.", "cuda", torch.float16)
        if control_model_file is not None:
-            self.model.control_model = self.model.control_model.to(
-                device=device, dtype=torch.float16
-            )
            control_model_ckpt = safe_open(
                control_model_file, framework="pt", device=device
            )
@@ -179,6 +184,9 @@ def __init__(
                dtype=torch.float16,
                remap=CONTROLNET_MAP,
            )
+
+            self.using_8b_controlnet = self.model.control_model.y_embedder.mlp[0].in_features == 2048
+            self.model.control_model.using_8b_controlnet = self.using_8b_controlnet
            control_model_ckpt = None
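
The detection above keys off the width of the ControlNet y_embedder's first linear layer (2048 input features marks the 8B variant). As a rough, self-contained illustration of the same idea applied directly to a ControlNet safetensors file, where the tensor key and helper name are assumptions for this sketch rather than anything the repo provides:

    from safetensors import safe_open

    def looks_like_8b_controlnet(path: str) -> bool:
        # Guess the 8B variant by inspecting the y_embedder's first Linear layer,
        # mirroring the in_features == 2048 check in the diff above.
        with safe_open(path, framework="pt", device="cpu") as f:
            for key in f.keys():
                # nn.Linear weights are stored as (out_features, in_features);
                # the exact key depends on the checkpoint's tensor naming.
                if key.endswith("y_embedder.mlp.0.weight"):
                    return f.get_tensor(key).shape[1] == 2048
        return False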
@@ -252,7 +260,7 @@ def load(
        model_folder: str = MODEL_FOLDER,
        text_encoder_device: str = "cpu",
        verbose=False,
-        load_tokenizers: bool = True
+        load_tokenizers: bool = True,
    ):
        self.verbose = verbose
        print("Loading tokenizers...")
@@ -374,19 +382,19 @@ def do_sampling(
        self.print("Sampling done")
        return latent

-    def vae_encode(self, image, controlnet_cond: bool = False) -> torch.Tensor:
+    def vae_encode(self, image, using_8b_controlnet: bool = False) -> torch.Tensor:
        self.print("Encoding image to latent...")
        image = image.convert("RGB")
        image_np = np.array(image).astype(np.float32) / 255.0
        image_np = np.moveaxis(image_np, 2, 0)
        batch_images = np.expand_dims(image_np, axis=0).repeat(1, axis=0)
        image_torch = torch.from_numpy(batch_images).cuda()
-        if not controlnet_cond:
+        if using_8b_controlnet:
            image_torch = 2.0 * image_torch - 1.0
+        else:
+            image_torch = image_torch * 255
        image_torch = image_torch.cuda()
        self.vae.model = self.vae.model.cuda()
-        if controlnet_cond:
-            image_torch = image_torch * 255
        latent = self.vae.model.encode(image_torch).cpu()
        self.vae.model = self.vae.model.cpu()
        self.print("Encoded")
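
With the reworked branch, 8B ControlNet conditioning images are normalized to the usual [-1, 1] VAE input range, while the earlier ControlNets still get the image rescaled to [0, 255]. A minimal standalone sketch of that preprocessing, assuming the same tensor layout as vae_encode (the function name is illustrative only):

    import numpy as np
    import torch
    from PIL import Image

    def preprocess_cond_image(image: Image.Image, using_8b_controlnet: bool) -> torch.Tensor:
        # HWC uint8 -> 1 x C x H x W float32 in [0, 1], mirroring vae_encode above
        arr = np.array(image.convert("RGB")).astype(np.float32) / 255.0
        tensor = torch.from_numpy(np.moveaxis(arr, 2, 0))[None]
        if using_8b_controlnet:
            return tensor * 2.0 - 1.0  # standard VAE input range [-1, 1]
        return tensor * 255.0  # earlier ControlNets expect [0, 255]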
@@ -411,10 +419,10 @@ def vae_decode(self, latent) -> Image.Image:
        self.print("Decoded")
        return out_image

-    def _image_to_latent(self, image, width, height, controlnet_cond: bool = False):
+    def _image_to_latent(self, image, width, height, using_8b_controlnet: bool = False):
        image_data = Image.open(image)
        image_data = image_data.resize((width, height), Image.LANCZOS)
-        latent = self.vae_encode(image_data, controlnet_cond)
+        latent = self.vae_encode(image_data, using_8b_controlnet)
        latent = SD3LatentFormat().process_in(latent)
        return latent

@@ -442,7 +450,7 @@ def gen_image(
        latent = latent.cuda()
        if controlnet_cond_image:
            controlnet_cond = self._image_to_latent(
-                controlnet_cond_image, width, height, True
+                controlnet_cond_image, width, height, self.sd3.using_8b_controlnet
            )
        neg_cond = self.get_cond("")
        seed_num = None
@@ -468,8 +476,9 @@ def gen_image(
                skip_layer_config,
            )
            image = self.vae_decode(sampled_latent)
+            os.makedirs(out_dir, exist_ok=True)
            save_path = os.path.join(out_dir, f"{i:06d}.png")
-            self.print(f"Will save to {save_path}")
+            self.print(f"Saving to {save_path}")
            image.save(save_path)
        self.print("Done")

@@ -553,7 +562,13 @@ def main(
    inferencer = SD3Inferencer()

    inferencer.load(
-        model, vae, shift, controlnet_ckpt, model_folder, text_encoder_device, verbose
+        model,
+        vae,
+        shift,
+        controlnet_ckpt,
+        model_folder,
+        text_encoder_device,
+        verbose,
    )

    if isinstance(prompt, str):
@@ -563,6 +578,7 @@ def main(
        else:
            prompts = [prompt]

+    sanitized_prompt = re.sub(r"[^\w\-\.]", "_", prompt)
    out_dir = os.path.join(
        out_dir,
        (
@@ -573,11 +589,9 @@ def main(
                else ""
            )
        ),
-        os.path.splitext(os.path.basename(prompt))[0][:50]
+        os.path.splitext(os.path.basename(sanitized_prompt))[0][:50]
        + (postfix or datetime.datetime.now().strftime("_%Y-%m-%dT%H-%M-%S")),
    )
-    print(f"Saving images to {out_dir}")
-    os.makedirs(out_dir, exist_ok=False)

    inferencer.gen_image(
        prompts,
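
The prompt sanitization keeps word characters, hyphens, and dots and replaces everything else with underscores, so the output directory name stays filesystem-safe even for free-form prompts. A quick illustration (the prompt string is just an example):

    import re

    prompt = "a photo of: a cat, 8k!"
    sanitized_prompt = re.sub(r"[^\w\-\.]", "_", prompt)
    print(sanitized_prompt[:50])  # -> a_photo_of__a_cat__8k_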