Skip to content

Commit 6ae9733

Browse files
minor cleanup
1 parent 232df68 commit 6ae9733

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

dit_embedder.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def __init__(
6262
dtype=dtype,
6363
device=device
6464
)
65-
self.is_8b = True
65+
self.is_8b = False
6666

6767
def forward(
6868
self,

sd3_infer.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -375,14 +375,17 @@ def do_sampling(
375375
self.print("Sampling done")
376376
return latent
377377

378-
def vae_encode(self, image, controlnet_cond: bool = False) -> torch.Tensor:
378+
def vae_encode(self, image, using_2b_controlnet: bool = False) -> torch.Tensor:
379379
self.print("Encoding image to latent...")
380380
image = image.convert("RGB")
381381
image_np = np.array(image).astype(np.float32) / 255.0
382382
image_np = np.moveaxis(image_np, 2, 0)
383383
batch_images = np.expand_dims(image_np, axis=0).repeat(1, axis=0)
384384
image_torch = torch.from_numpy(batch_images).cuda()
385-
image_torch = 2.0 * image_torch - 1.0
385+
if using_2b_controlnet:
386+
image_torch = image_torch * 255
387+
else:
388+
image_torch = 2.0 * image_torch - 1.0
386389
image_torch = image_torch.cuda()
387390
self.vae.model = self.vae.model.cuda()
388391
latent = self.vae.model.encode(image_torch).cpu()

0 commit comments

Comments
 (0)