Commit 232df68: more fixes for controlnet
committed Nov 20, 2024
1 parent: 8075776

File tree: 3 files changed, +12 -11 lines changed
  .gitignore
  evaluate.py
  sd3_infer.py


.gitignore (+2 -1)

@@ -167,4 +167,5 @@ cython_debug/
 #.idea/
 
 .vscode/
-*.out.*
+*.out.*
+*.pt

evaluate.py (+6 -5)

@@ -51,7 +51,6 @@ def load_pickles(folder_path):
     # Filter and sort the files
     pickle_files = sorted(
         [file for file in files if file.startswith("data_") and file.endswith(".pkl")],
-        key=lambda x: int(x.split("_")[1].split(".")[0]),
     )
 
     data_list = []
@@ -124,7 +123,7 @@ def main(
         controlnet_ckpt,
         model_folder,
         text_encoder_device,
-        load_tokenizers=True,
+        load_tokenizers=False,
     )
 
     print(f"Saving images to {out_dir}")
@@ -150,8 +149,8 @@ def _get_precomputed_cond(sample):
     # torch.save(neg_cond[0], os.path.join(out_dir, "neg_cond_0.pt"))
     # torch.save(neg_cond[1], os.path.join(out_dir, "neg_cond_1.pt"))
     neg_cond = (
-        torch.load(os.path.join("outputs", "neg_cond_0.pt")),
-        torch.load(os.path.join("outputs", "neg_cond_1.pt")),
+        torch.load("neg_cond_0.pt"),
+        torch.load("neg_cond_1.pt"),
     )
 
     for i, sample in tqdm(enumerate(dataset)):
@@ -161,7 +160,9 @@ def _get_precomputed_cond(sample):
         else:
             latent = inferencer.get_empty_latent(1, width, height, seed, "cpu")
             latent = latent.cuda()
-        controlnet_cond = inferencer.vae_encode_tensor(sample["vae_f8_ch16.cond.sft.latent"])
+        controlnet_cond = inferencer.vae_encode_tensor(
+            sample["vae_f8_ch16.cond.sft.latent"]
+        )
         conditioning = _get_precomputed_cond(sample)
         seed_num = 42
         sampled_latent = inferencer.do_sampling(
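Note on the first hunk: without the key argument, sorted() orders the pickle filenames lexicographically rather than by their numeric suffix, so "data_10.pkl" now sorts before "data_2.pkl". A minimal standalone illustration (the filenames are made up):

    # Lexicographic order, as in the updated load_pickles (no key argument):
    files = ["data_2.pkl", "data_10.pkl", "data_1.pkl"]
    print(sorted(files))
    # -> ['data_1.pkl', 'data_10.pkl', 'data_2.pkl']

    # Numeric order, as the removed key produced:
    print(sorted(files, key=lambda x: int(x.split("_")[1].split(".")[0])))
    # -> ['data_1.pkl', 'data_2.pkl', 'data_10.pkl']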

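Note on the third hunk: the torch.load calls now expect neg_cond_0.pt and neg_cond_1.pt in the current working directory rather than under outputs/. A minimal sketch of the precompute-then-load round trip implied by the commented-out torch.save lines; neg_cond is assumed to be the 2-tuple of tensors that evaluate.py builds earlier:

    import os
    import torch

    def save_neg_cond(neg_cond, out_dir="."):
        # One-off save of the two negative-conditioning tensors
        # (mirrors the commented-out torch.save lines in the diff).
        torch.save(neg_cond[0], os.path.join(out_dir, "neg_cond_0.pt"))
        torch.save(neg_cond[1], os.path.join(out_dir, "neg_cond_1.pt"))

    def load_neg_cond(out_dir="."):
        # Later runs reload the tuple instead of recomputing it.
        return (
            torch.load(os.path.join(out_dir, "neg_cond_0.pt")),
            torch.load(os.path.join(out_dir, "neg_cond_1.pt")),
        )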
sd3_infer.py (+4 -5)

@@ -382,8 +382,7 @@ def vae_encode(self, image, controlnet_cond: bool = False) -> torch.Tensor:
         image_np = np.moveaxis(image_np, 2, 0)
         batch_images = np.expand_dims(image_np, axis=0).repeat(1, axis=0)
         image_torch = torch.from_numpy(batch_images).cuda()
-        if not controlnet_cond:
-            image_torch = 2.0 * image_torch - 1.0
+        image_torch = 2.0 * image_torch - 1.0
         image_torch = image_torch.cuda()
         self.vae.model = self.vae.model.cuda()
         latent = self.vae.model.encode(image_torch).cpu()
@@ -400,9 +399,9 @@ def vae_encode_pkl(self, pkl_location: str) -> torch.Tensor:
         return latent
 
     def vae_encode_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
-        latent, _ = DiagonalGaussianRegularizer()(tensor)
-        latent = SD3LatentFormat().process_in(latent)
-        return latent
+        tensor, _ = DiagonalGaussianRegularizer()(tensor)
+        tensor = SD3LatentFormat().process_in(tensor)
+        return tensor
 
     def vae_decode(self, latent) -> Image.Image:
         self.print("Decoding latent to image...")
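For context, a rough usage sketch of the two encode paths this commit touches, as they are driven from evaluate.py. It is only an illustration: the helper name is hypothetical, and it assumes vae_encode accepts a PIL image and that the dataset sample carries the "vae_f8_ch16.cond.sft.latent" key used in evaluate.py:

    import torch
    from PIL import Image

    def encode_controlnet_cond(inferencer, sample=None, image_path=None) -> torch.Tensor:
        # Hypothetical helper; `inferencer` stands for the SD3 inferencer object
        # that evaluate.py constructs.
        if sample is not None:
            # Precomputed route (what evaluate.py now uses): the sample already
            # holds a raw VAE posterior tensor; vae_encode_tensor samples it via
            # DiagonalGaussianRegularizer and applies SD3LatentFormat().process_in.
            return inferencer.vae_encode_tensor(sample["vae_f8_ch16.cond.sft.latent"])
        # Image route: after this commit vae_encode always rescales the pixels
        # from [0, 1] to [-1, 1] before the VAE, even when controlnet_cond=True.
        image = Image.open(image_path).convert("RGB")
        return inferencer.vae_encode(image, controlnet_cond=True)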
