@@ -57,7 +57,7 @@ def load_into(ckpt, model, prefix, device, dtype=None, remap=None):
                 continue
             try:
                 tensor = ckpt.get_tensor(key).to(device=device)
-                if dtype is not None:
+                if dtype is not None and tensor.dtype != torch.int32:
                     tensor = tensor.to(dtype=dtype)
                 obj.requires_grad_(False)
                 # print(f"K: {model_key}, O: {obj.shape} T: {tensor.shape}")
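The added `tensor.dtype != torch.int32` guard presumably keeps integer tensors in the checkpoint (e.g. index or lookup tables stored as int32) from being silently cast to the requested float dtype. A minimal standalone sketch of the same check, with `cast_preserving_int32` as a hypothetical helper name:

```python
import torch

def cast_preserving_int32(tensor: torch.Tensor, dtype) -> torch.Tensor:
    # Cast to the requested dtype, but leave int32 tensors untouched so that
    # index-like data is not silently converted to floating point.
    if dtype is not None and tensor.dtype != torch.int32:
        tensor = tensor.to(dtype=dtype)
    return tensor

print(cast_preserving_int32(torch.arange(3, dtype=torch.int32), torch.float16).dtype)  # torch.int32
print(cast_preserving_int32(torch.rand(3), torch.float16).dtype)                       # torch.float16
```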
@@ -385,15 +385,19 @@ def do_sampling(
         self.print("Sampling done")
         return latent
 
-    def vae_encode(self, image, using_2b_controlnet: bool = False) -> torch.Tensor:
+    def vae_encode(
+        self, image, using_2b_controlnet: bool = False, controlnet_type: int = 0
+    ) -> torch.Tensor:
         self.print("Encoding image to latent...")
         image = image.convert("RGB")
         image_np = np.array(image).astype(np.float32) / 255.0
         image_np = np.moveaxis(image_np, 2, 0)
         batch_images = np.expand_dims(image_np, axis=0).repeat(1, axis=0)
         image_torch = torch.from_numpy(batch_images).cuda()
         if using_2b_controlnet:
-            image_torch = image_torch * 255
+            image_torch = image_torch * 2.0 - 1.0
+        elif controlnet_type == 1:  # canny
+            image_torch = image_torch * 255 * 0.5 + 0.5
         else:
             image_torch = 2.0 * image_torch - 1.0
         image_torch = image_torch.cuda()
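At this point the conditioning image is already a float tensor in [0, 1]; the branches only change the value range handed to the ControlNet: 2B ControlNets and the default path both map to [-1, 1], while canny (`controlnet_type == 1`) stays on a 0-255 scale, ending up in roughly [0.5, 128]. A small standalone sketch of that mapping, with `scale_control_image` as a hypothetical helper:

```python
import torch

def scale_control_image(
    image_01: torch.Tensor, using_2b_controlnet: bool, controlnet_type: int
) -> torch.Tensor:
    """Rescale a [0, 1] conditioning image to the range a given ControlNet expects."""
    if using_2b_controlnet:
        return image_01 * 2.0 - 1.0        # [-1, 1]
    if controlnet_type == 1:               # canny
        return image_01 * 255 * 0.5 + 0.5  # roughly [0.5, 128]
    return 2.0 * image_01 - 1.0            # [-1, 1] for the remaining ControlNets

x = torch.rand(1, 3, 8, 8)
print(scale_control_image(x, False, 1).max())  # < 128.0
print(scale_control_image(x, True, 0).min())   # >= -1.0
```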
@@ -422,10 +426,17 @@ def vae_decode(self, latent) -> Image.Image:
         self.print("Decoded")
         return out_image
 
-    def _image_to_latent(self, image, width, height, using_2b_controlnet: bool = False):
+    def _image_to_latent(
+        self,
+        image,
+        width,
+        height,
+        using_2b_controlnet: bool = False,
+        controlnet_type: int = 0,
+    ) -> torch.Tensor:
         image_data = Image.open(image)
         image_data = image_data.resize((width, height), Image.LANCZOS)
-        latent = self.vae_encode(image_data, using_2b_controlnet)
+        latent = self.vae_encode(image_data, using_2b_controlnet, controlnet_type)
         latent = SD3LatentFormat().process_in(latent)
         return latent
 
@@ -452,12 +463,12 @@ def gen_image(
         latent = self.get_empty_latent(1, width, height, seed, "cpu")
         latent = latent.cuda()
         if controlnet_cond_image:
-            using_2b_controlnet = (
-                self.sd3.model.control_model is not None
-                and not self.sd3.using_8b_controlnet
-            )
+            using_2b, control_type = False, 0
+            if self.sd3.model.control_model is not None:
+                using_2b = not self.sd3.using_8b_controlnet
+                control_type = int(self.sd3.model.control_model.control_type.item())
             controlnet_cond = self._image_to_latent(
-                controlnet_cond_image, width, height, using_2b_controlnet
+                controlnet_cond_image, width, height, using_2b, control_type
             )
         neg_cond = self.get_cond("")
         seed_num = None
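`gen_image` now derives the ControlNet flavor from the loaded control model itself rather than from the model size alone: the diff reads `control_model.control_type` and calls `.item()` on it, so the type appears to be stored as a scalar tensor (likely a buffer) that is converted to a plain `int` before being threaded through `_image_to_latent` into `vae_encode`. A minimal sketch of that read, with `DummyControlNet` standing in for the real control model:

```python
import torch

class DummyControlNet(torch.nn.Module):
    def __init__(self, control_type: int):
        super().__init__()
        # A buffer travels with the module across .to(device) calls and checkpoints.
        self.register_buffer("control_type", torch.tensor(control_type))

control_model = DummyControlNet(control_type=1)  # 1 == canny in the diff above
control_type = int(control_model.control_type.item())
print(control_type)  # 1
```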