Skip to content

Commit 0221c91

Browse files
committed Nov 19, 2024 ·
cleanup, remove non-essential stuff
1 parent 73299cc commit 0221c91

7 files changed

+9
-266
lines changed
 

‎cn_eval.out.538931

-5
This file was deleted.

‎dit_embedder.py

+3 -3
Original file line number | Diff line number | Diff line change
@@ -58,7 +58,6 @@ def __init__(
5858
embed_dim=self.hidden_size,
5959
strict_img_size=False,
6060
)
61-
self.is_8b = True
6261

6362
def forward(
6463
self,
@@ -67,9 +66,10 @@ def forward(
6766
y: Tensor,
6867
scale: int = 1,
6968
timestep: Optional[Tensor] = None,
69+
is_8b: bool = False
7070
) -> Tuple[Tensor, List[Tensor]]:
7171

72-
if not self.is_8b:
72+
if not is_8b:
7373
x = self.x_embedder(x)
7474
timestep = timestep * 1000
7575
c = self.t_embedder(timestep, dtype=x.dtype)
@@ -83,7 +83,7 @@ def forward(
8383

8484
for block in self.transformer_blocks:
8585
out = block(x, c)
86-
if self.is_8b:
86+
if is_8b:
8787
x = out
8888
block_out += (out,)
8989

‎evaluate.py

-136
This file was deleted.

‎sd3_impls.py

+6 -79
Original file line number | Diff line number | Diff line change
@@ -159,21 +159,17 @@ def apply_model(self, x, sigma, c_crossattn=None, y=None, skip_layers=[], contro
159159
controlnet_cond = controlnet_cond.to(dtype=x.dtype, device=x.device)
160160
controlnet_cond = controlnet_cond.repeat(x.shape[0], 1, 1, 1)
161161

162-
# Some ControlNets don't use the y_cond input, so we need to check if it's needed.
163-
if y_cond.shape[-1] != self.control_model.y_embedder.mlp[0].in_features:
162+
# 8B ControlNets were trained with a slightly different architecture.
163+
is_8b = y_cond.shape[-1] == self.control_model.y_embedder.mlp[0].in_features
164+
if not is_8b:
164165
y_cond = self.diffusion_model.y_embedder(y)
165-
hw = x.shape[-2:]
166166

167167
x_controlnet = x
168-
# HACK
169-
# x_controlnet = torch.load("/weka/home-brianf/x_8b.pt")
170-
# controlnet_cond = torch.load("/weka/home-brianf/x_cond_8b.pt")
171-
# y_cond = torch.load("/weka/home-brianf/y_cond_8b.pt")
172-
if self.control_model.is_8b:
168+
if is_8b:
169+
hw = x.shape[-2:]
173170
x_controlnet = self.diffusion_model.x_embedder(x) + self.diffusion_model.cropped_pos_embed(hw)
174-
# y_cond[0] = torch.zeros_like(y_cond[0])
175171
controlnet_hidden_states = self.control_model(
176-
x_controlnet, controlnet_cond, y_cond, 1, sigma.to(torch.float32)
172+
x_controlnet, controlnet_cond, y_cond, 1, sigma.to(torch.float32), is_8b
177173
)
178174
model_output = self.diffusion_model(
179175
x.to(dtype),
@@ -747,72 +743,3 @@ def encode(self, image):
747743
std = torch.exp(0.5 * logvar)
748744
return mean + std * torch.randn_like(mean)
749745

750-
751-
class DiagonalGaussianDistribution:
752-
def __init__(self, parameters, deterministic=False, chunk_dim: int = 1):
753-
self.parameters = parameters
754-
self.mean, self.logvar = torch.chunk(parameters, 2, dim=chunk_dim)
755-
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
756-
self.deterministic = deterministic
757-
self.std = torch.exp(0.5 * self.logvar)
758-
self.var = torch.exp(self.logvar)
759-
if self.deterministic:
760-
self.var = self.std = torch.zeros_like(self.mean).to(
761-
device=self.parameters.device
762-
)
763-
764-
def sample(self):
765-
x = self.mean + self.std * torch.randn(self.mean.shape).to(
766-
device=self.parameters.device
767-
)
768-
return x
769-
770-
def kl(self, other=None):
771-
if self.deterministic:
772-
return torch.Tensor([0.0])
773-
else:
774-
if other is None:
775-
return 0.5 * torch.sum(
776-
torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
777-
dim=list(range(1, self.mean.ndim)),
778-
)
779-
else:
780-
return 0.5 * torch.sum(
781-
torch.pow(self.mean - other.mean, 2) / other.var
782-
+ self.var / other.var
783-
- 1.0
784-
- self.logvar
785-
+ other.logvar,
786-
dim=list(range(1, self.mean.ndim)),
787-
)
788-
789-
def nll(self, sample, dims=[1, 2, 3]):
790-
if self.deterministic:
791-
return torch.Tensor([0.0])
792-
logtwopi = np.log(2.0 * np.pi)
793-
return 0.5 * torch.sum(
794-
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
795-
dim=dims,
796-
)
797-
798-
def mode(self):
799-
return self.mean
800-
801-
802-
class DiagonalGaussianRegularizer(nn.Module):
803-
def __init__(self, sample: bool = True, chunk_dim: int = 1):
804-
super().__init__()
805-
self.sample = sample
806-
self.chunk_dim = chunk_dim
807-
808-
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
809-
log = dict()
810-
posterior = DiagonalGaussianDistribution(z, chunk_dim=self.chunk_dim)
811-
if self.sample:
812-
z = posterior.sample()
813-
else:
814-
z = posterior.mode()
815-
kl_loss = posterior.kl()
816-
kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
817-
log["kl_loss"] = kl_loss
818-
return z, log

‎sd3_infer.py

-11
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,6 @@
2424
SDVAE,
2525
BaseModel,
2626
CFGDenoiser,
27-
DiagonalGaussianRegularizer,
2827
SD3LatentFormat,
2928
SkipLayerCFGDenoiser,
3029
)
@@ -393,17 +392,8 @@ def vae_encode(self, image, controlnet_cond: bool = False) -> torch.Tensor:
393392
self.print("Encoded")
394393
return latent
395394

396-
def vae_encode_pkl(self, pkl_location: str) -> torch.Tensor:
397-
with open(pkl_location, "rb") as f:
398-
data = pickle.load(f)
399-
latent = data["vae_f8_ch16.cond.sft.latent"]
400-
latent, _ = DiagonalGaussianRegularizer()(latent)
401-
latent = SD3LatentFormat().process_in(latent)
402-
return latent
403-
404395
def vae_encode_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
405396
tensor = tensor.unsqueeze(0)
406-
latent, _ = DiagonalGaussianRegularizer()(tensor)
407397
latent = SD3LatentFormat().process_in(latent)
408398
return latent
409399

@@ -454,7 +444,6 @@ def gen_image(
454444
controlnet_cond = self._image_to_latent(
455445
controlnet_cond_image, width, height, True
456446
)
457-
# controlnet_cond = self.vae_encode_pkl("/weka/home-brianf/controlnet_val/canny_8_3/pkl/data_6.pkl")
458447
neg_cond = self.get_cond("")
459448
seed_num = None
460449
pbar = tqdm(enumerate(prompts), position=0, leave=True)

‎submit_all_evals.sh

-3
This file was deleted.

‎submit_eval.sh

-29
This file was deleted.

0 commit comments

Comments
 (0)
Please sign in to comment.