 from PIL import Image
 from safetensors import safe_open
 from tqdm import tqdm
+import re
 
 import sd3_impls
 from other_impls import SD3Tokenizer, SDClipModel, SDXLClipG, T5XXLModel
@@ -61,7 +62,9 @@ def load_into(ckpt, model, prefix, device, dtype=None, remap=None):
                 obj.requires_grad_(False)
                 # print(f"K: {model_key}, O: {obj.shape} T: {tensor.shape}")
                 if obj.shape != tensor.shape:
-                    print(f"W: shape mismatch for key {model_key}, {obj.shape} != {tensor.shape}")
+                    print(
+                        f"W: shape mismatch for key {model_key}, {obj.shape} != {tensor.shape}"
+                    )
                 obj.set_(tensor)
             except Exception as e:
                 print(f"Failed to load key '{key}' in safetensors file: {e}")
@@ -148,6 +151,11 @@ class SD3:
     def __init__(
         self, model, shift, control_model_file=None, verbose=False, device="cpu"
     ):
+
+        # NOTE: 8B ControlNets were trained with a slightly different forward pass and conditioning,
+        # so this is a flag to enable that logic.
+        self.using_8b_controlnet = False
+
         with safe_open(model, framework="pt", device="cpu") as f:
             control_model_ckpt = None
             if control_model_file is not None:
@@ -165,9 +173,6 @@ def __init__(
             ).eval()
             load_into(f, self.model, "model.", "cuda", torch.float16)
         if control_model_file is not None:
-            self.model.control_model = self.model.control_model.to(
-                device=device, dtype=torch.float16
-            )
             control_model_ckpt = safe_open(
                 control_model_file, framework="pt", device=device
             )
@@ -179,6 +184,9 @@ def __init__(
                 dtype=torch.float16,
                 remap=CONTROLNET_MAP,
             )
+
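+            # Detect the 8B ControlNet by the input width of its pooled-conditioning (y) embedder.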
+            self.using_8b_controlnet = self.model.control_model.y_embedder.mlp[0].in_features == 2048
+            self.model.control_model.using_8b_controlnet = self.using_8b_controlnet
         control_model_ckpt = None
 
 
@@ -252,7 +260,7 @@ def load(
         model_folder: str = MODEL_FOLDER,
         text_encoder_device: str = "cpu",
         verbose=False,
-        load_tokenizers: bool = True
+        load_tokenizers: bool = True,
     ):
         self.verbose = verbose
         print("Loading tokenizers...")
@@ -374,19 +382,19 @@ def do_sampling(
         self.print("Sampling done")
         return latent
 
-    def vae_encode(self, image, controlnet_cond: bool = False) -> torch.Tensor:
+    def vae_encode(self, image, using_8b_controlnet: bool = False) -> torch.Tensor:
         self.print("Encoding image to latent...")
         image = image.convert("RGB")
         image_np = np.array(image).astype(np.float32) / 255.0
         image_np = np.moveaxis(image_np, 2, 0)
         batch_images = np.expand_dims(image_np, axis=0).repeat(1, axis=0)
         image_torch = torch.from_numpy(batch_images).cuda()
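+        # Scale to [-1, 1] for the 8B ControlNet; otherwise keep the 0-255 range used previously for ControlNet conditioning.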
-        if not controlnet_cond:
+        if using_8b_controlnet:
             image_torch = 2.0 * image_torch - 1.0
+        else:
+            image_torch = image_torch * 255
         image_torch = image_torch.cuda()
         self.vae.model = self.vae.model.cuda()
-        if controlnet_cond:
-            image_torch = image_torch * 255
         latent = self.vae.model.encode(image_torch).cpu()
         self.vae.model = self.vae.model.cpu()
         self.print("Encoded")
@@ -411,10 +419,10 @@ def vae_decode(self, latent) -> Image.Image:
         self.print("Decoded")
         return out_image
 
-    def _image_to_latent(self, image, width, height, controlnet_cond: bool = False):
+    def _image_to_latent(self, image, width, height, using_8b_controlnet: bool = False):
         image_data = Image.open(image)
         image_data = image_data.resize((width, height), Image.LANCZOS)
-        latent = self.vae_encode(image_data, controlnet_cond)
+        latent = self.vae_encode(image_data, using_8b_controlnet)
         latent = SD3LatentFormat().process_in(latent)
         return latent
 
@@ -442,7 +450,7 @@ def gen_image(
             latent = latent.cuda()
         if controlnet_cond_image:
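+            # The conditioning image is scaled differently when the loaded ControlNet is the 8B variant (see vae_encode).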
             controlnet_cond = self._image_to_latent(
-                controlnet_cond_image, width, height, True
+                controlnet_cond_image, width, height, self.sd3.using_8b_controlnet
             )
         neg_cond = self.get_cond("")
         seed_num = None
@@ -468,8 +476,9 @@ def gen_image(
                 skip_layer_config,
             )
             image = self.vae_decode(sampled_latent)
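+            # Create the output directory if needed; later images in this run reuse it.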
+            os.makedirs(out_dir, exist_ok=True)
             save_path = os.path.join(out_dir, f"{i:06d}.png")
-            self.print(f"Will save to {save_path}")
+            self.print(f"Saving to {save_path}")
             image.save(save_path)
         self.print("Done")
 
@@ -553,7 +562,13 @@ def main(
     inferencer = SD3Inferencer()
 
     inferencer.load(
-        model, vae, shift, controlnet_ckpt, model_folder, text_encoder_device, verbose
+        model,
+        vae,
+        shift,
+        controlnet_ckpt,
+        model_folder,
+        text_encoder_device,
+        verbose,
     )
 
     if isinstance(prompt, str):
@@ -563,6 +578,7 @@ def main(
     else:
         prompts = [prompt]
 
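+    # Make the prompt safe to use as part of the output directory name.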
+    sanitized_prompt = re.sub(r"[^\w\-\.]", "_", prompt)
     out_dir = os.path.join(
         out_dir,
         (
@@ -573,11 +589,9 @@ def main(
                 else ""
             )
         ),
-        os.path.splitext(os.path.basename(prompt))[0][:50]
+        os.path.splitext(os.path.basename(sanitized_prompt))[0][:50]
         + (postfix or datetime.datetime.now().strftime("_%Y-%m-%dT%H-%M-%S")),
     )
-    print(f"Saving images to {out_dir}")
-    os.makedirs(out_dir, exist_ok=False)
 
     inferencer.gen_image(
         prompts,