Skip to content

Commit d343ae0

Browse files
fixed latent encoder behavior based on control type
1 parent 50ba6fc commit d343ae0

File tree

2 files changed

+24
-10
lines changed

2 files changed

+24
-10
lines changed

dit_embedder.py

+3
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ def __init__(
3636
dtype=dtype,
3737
)
3838

39+
# blur = 0, canny = 1, depth = 2
40+
self.control_type = torch.tensor([0], dtype=torch.int32, device=device)
41+
3942
self.t_embedder = TimestepEmbedder(self.hidden_size, dtype=dtype, device=device)
4043
self.y_embedder = VectorEmbedder(
4144
pooled_projection_size, self.hidden_size, dtype, device

sd3_infer.py

+21-10
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def load_into(ckpt, model, prefix, device, dtype=None, remap=None):
5757
continue
5858
try:
5959
tensor = ckpt.get_tensor(key).to(device=device)
60-
if dtype is not None:
60+
if dtype is not None and tensor.dtype != torch.int32:
6161
tensor = tensor.to(dtype=dtype)
6262
obj.requires_grad_(False)
6363
# print(f"K: {model_key}, O: {obj.shape} T: {tensor.shape}")
@@ -385,15 +385,19 @@ def do_sampling(
385385
self.print("Sampling done")
386386
return latent
387387

388-
def vae_encode(self, image, using_2b_controlnet: bool = False) -> torch.Tensor:
388+
def vae_encode(
389+
self, image, using_2b_controlnet: bool = False, controlnet_type: int = 0
390+
) -> torch.Tensor:
389391
self.print("Encoding image to latent...")
390392
image = image.convert("RGB")
391393
image_np = np.array(image).astype(np.float32) / 255.0
392394
image_np = np.moveaxis(image_np, 2, 0)
393395
batch_images = np.expand_dims(image_np, axis=0).repeat(1, axis=0)
394396
image_torch = torch.from_numpy(batch_images).cuda()
395397
if using_2b_controlnet:
396-
image_torch = image_torch * 255
398+
image_torch = image_torch * 2.0 - 1.0
399+
elif controlnet_type == 1: # canny
400+
image_torch = image_torch * 255 * 0.5 + 0.5
397401
else:
398402
image_torch = 2.0 * image_torch - 1.0
399403
image_torch = image_torch.cuda()
@@ -422,10 +426,17 @@ def vae_decode(self, latent) -> Image.Image:
422426
self.print("Decoded")
423427
return out_image
424428

425-
def _image_to_latent(self, image, width, height, using_2b_controlnet: bool = False):
429+
def _image_to_latent(
430+
self,
431+
image,
432+
width,
433+
height,
434+
using_2b_controlnet: bool = False,
435+
controlnet_type: int = 0,
436+
) -> torch.Tensor:
426437
image_data = Image.open(image)
427438
image_data = image_data.resize((width, height), Image.LANCZOS)
428-
latent = self.vae_encode(image_data, using_2b_controlnet)
439+
latent = self.vae_encode(image_data, using_2b_controlnet, controlnet_type)
429440
latent = SD3LatentFormat().process_in(latent)
430441
return latent
431442

@@ -452,12 +463,12 @@ def gen_image(
452463
latent = self.get_empty_latent(1, width, height, seed, "cpu")
453464
latent = latent.cuda()
454465
if controlnet_cond_image:
455-
using_2b_controlnet = (
456-
self.sd3.model.control_model is not None
457-
and not self.sd3.using_8b_controlnet
458-
)
466+
using_2b, control_type = False, 0
467+
if self.sd3.model.control_model is not None:
468+
using_2b = not self.sd3.using_8b_controlnet
469+
control_type = int(self.sd3.model.control_model.control_type.item())
459470
controlnet_cond = self._image_to_latent(
460-
controlnet_cond_image, width, height, using_2b_controlnet
471+
controlnet_cond_image, width, height, using_2b, control_type
461472
)
462473
neg_cond = self.get_cond("")
463474
seed_num = None

0 commit comments

Comments
 (0)