Skip to content

Commit 56e772a

Browse files
authored
Use model_info.id instead of model_info.modelId (huggingface#8912)
Mention model_info.id instead of model_info.modelId
1 parent fe79489 commit 56e772a

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

scripts/generate_logits.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -103,12 +103,12 @@
103103

104104
models = api.list_models(filter="diffusers")
105105
for mod in models:
106-
if "google" in mod.author or mod.modelId == "CompVis/ldm-celebahq-256":
107-
local_checkpoint = "/home/patrick/google_checkpoints/" + mod.modelId.split("/")[-1]
106+
if "google" in mod.author or mod.id == "CompVis/ldm-celebahq-256":
107+
local_checkpoint = "/home/patrick/google_checkpoints/" + mod.id.split("/")[-1]
108108

109-
print(f"Started running {mod.modelId}!!!")
109+
print(f"Started running {mod.id}!!!")
110110

111-
if mod.modelId.startswith("CompVis"):
111+
if mod.id.startswith("CompVis"):
112112
model = UNet2DModel.from_pretrained(local_checkpoint, subfolder="unet")
113113
else:
114114
model = UNet2DModel.from_pretrained(local_checkpoint)
@@ -122,6 +122,6 @@
122122
logits = model(noise, time_step).sample
123123

124124
assert torch.allclose(
125-
logits[0, 0, 0, :30], results["_".join("_".join(mod.modelId.split("/")).split("-"))], atol=1e-3
125+
logits[0, 0, 0, :30], results["_".join("_".join(mod.id.split("/")).split("-"))], atol=1e-3
126126
)
127-
print(f"{mod.modelId} has passed successfully!!!")
127+
print(f"{mod.id} has passed successfully!!!")

0 commit comments

Comments
 (0)