17
17
from PIL import Image
18
18
from safetensors import safe_open
19
19
from tqdm import tqdm
20
+ import re
20
21
21
22
import sd3_impls
22
23
from other_impls import SD3Tokenizer , SDClipModel , SDXLClipG , T5XXLModel
@@ -168,9 +169,6 @@ def __init__(
168
169
).eval ()
169
170
load_into (f , self .model , "model." , "cuda" , torch .float16 )
170
171
if control_model_file is not None :
171
- self .model .control_model = self .model .control_model .to (
172
- device = device , dtype = torch .float16
173
- )
174
172
control_model_ckpt = safe_open (
175
173
control_model_file , framework = "pt" , device = device
176
174
)
@@ -388,8 +386,6 @@ def vae_encode(self, image, controlnet_cond: bool = False) -> torch.Tensor:
388
386
image_torch = 2.0 * image_torch - 1.0
389
387
image_torch = image_torch .cuda ()
390
388
self .vae .model = self .vae .model .cuda ()
391
- if controlnet_cond :
392
- image_torch = image_torch * 255
393
389
latent = self .vae .model .encode (image_torch ).cpu ()
394
390
self .vae .model = self .vae .model .cpu ()
395
391
self .print ("Encoded" )
@@ -426,6 +422,7 @@ def _image_to_latent(self, image, width, height, controlnet_cond: bool = False):
426
422
image_data = Image .open (image )
427
423
image_data = image_data .resize ((width , height ), Image .LANCZOS )
428
424
latent = self .vae_encode (image_data , controlnet_cond )
425
+ # latent, _ = DiagonalGaussianRegularizer()(latent)
429
426
latent = SD3LatentFormat ().process_in (latent )
430
427
return latent
431
428
@@ -480,8 +477,9 @@ def gen_image(
480
477
skip_layer_config ,
481
478
)
482
479
image = self .vae_decode (sampled_latent )
480
+ os .makedirs (out_dir , exist_ok = False )
483
481
save_path = os .path .join (out_dir , f"{ i :06d} .png" )
484
- self .print (f"Will save to { save_path } " )
482
+ self .print (f"Saving to to { save_path } " )
485
483
image .save (save_path )
486
484
self .print ("Done" )
487
485
@@ -572,7 +570,6 @@ def main(
572
570
model_folder ,
573
571
text_encoder_device ,
574
572
verbose ,
575
- load_tokenizers = False ,
576
573
)
577
574
578
575
if isinstance (prompt , str ):
@@ -582,6 +579,7 @@ def main(
582
579
else :
583
580
prompts = [prompt ]
584
581
582
+ sanitized_prompt = re .sub (r'[^\w\-\.]' , '_' , prompt )
585
583
out_dir = os .path .join (
586
584
out_dir ,
587
585
(
@@ -592,11 +590,9 @@ def main(
592
590
else ""
593
591
)
594
592
),
595
- os .path .splitext (os .path .basename (prompt ))[0 ][:50 ]
593
+ os .path .splitext (os .path .basename (sanitized_prompt ))[0 ][:50 ]
596
594
+ (postfix or datetime .datetime .now ().strftime ("_%Y-%m-%dT%H-%M-%S" )),
597
595
)
598
- print (f"Saving images to { out_dir } " )
599
- os .makedirs (out_dir , exist_ok = False )
600
596
601
597
inferencer .gen_image (
602
598
prompts ,
0 commit comments