@@ -159,21 +159,17 @@ def apply_model(self, x, sigma, c_crossattn=None, y=None, skip_layers=[], contro
             controlnet_cond = controlnet_cond.to(dtype=x.dtype, device=x.device)
             controlnet_cond = controlnet_cond.repeat(x.shape[0], 1, 1, 1)

-            # Some ControlNets don't use the y_cond input, so we need to check if it's needed.
-            if y_cond.shape[-1] != self.control_model.y_embedder.mlp[0].in_features:
+            # 8B ControlNets were trained with a slightly different architecture.
+            is_8b = y_cond.shape[-1] == self.control_model.y_embedder.mlp[0].in_features
+            if not is_8b:
                 y_cond = self.diffusion_model.y_embedder(y)
-            hw = x.shape[-2:]

             x_controlnet = x
-            # HACK
-            # x_controlnet = torch.load("/weka/home-brianf/x_8b.pt")
-            # controlnet_cond = torch.load("/weka/home-brianf/x_cond_8b.pt")
-            # y_cond = torch.load("/weka/home-brianf/y_cond_8b.pt")
-            if self.control_model.is_8b:
+            if is_8b:
+                hw = x.shape[-2:]
                 x_controlnet = self.diffusion_model.x_embedder(x) + self.diffusion_model.cropped_pos_embed(hw)
-                # y_cond[0] = torch.zeros_like(y_cond[0])
             controlnet_hidden_states = self.control_model(
-                x_controlnet, controlnet_cond, y_cond, 1, sigma.to(torch.float32)
+                x_controlnet, controlnet_cond, y_cond, 1, sigma.to(torch.float32), is_8b
             )
             model_output = self.diffusion_model(
                 x.to(dtype),
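For context, this hunk switches from a stored flag (`self.control_model.is_8b`) to inferring the ControlNet variant from tensor shapes: when the pooled conditioning `y_cond` already matches the input width of the control model's `y_embedder` MLP, the 8B path is taken and the latent is pre-embedded with `x_embedder` plus the cropped positional embedding before entering the ControlNet. A minimal sketch of that shape-based dispatch, assuming illustrative widths (`Y_IN_FEATURES`, the 2048/1024 dimensions, and the stand-in `nn.Linear` embedder are placeholders, not from this repo):

    import torch
    from torch import nn

    # Hypothetical widths; only the shape comparison matters here.
    Y_IN_FEATURES = 2048                         # stand-in for control_model.y_embedder.mlp[0].in_features
    y_embedder = nn.Linear(1024, Y_IN_FEATURES)  # stand-in for diffusion_model.y_embedder

    def route_y_cond(y_cond: torch.Tensor, y: torch.Tensor):
        # 8B ControlNets consume the pooled conditioning directly, so its last
        # dim already equals the control model's y_embedder input width.
        is_8b = y_cond.shape[-1] == Y_IN_FEATURES
        if not is_8b:
            # Smaller variants expect y re-embedded by the base model instead.
            y_cond = y_embedder(y)
        return is_8b, y_cond

    is_8b, y_cond = route_y_cond(torch.randn(1, Y_IN_FEATURES), torch.randn(1, 1024))
    print(is_8b)  # True: widths match, so the 8B path is taken

Keying the branch off `in_features` presumably keeps existing checkpoints loadable without a new config field, at the cost of misrouting if two variants ever share that width.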
@@ -747,72 +743,3 @@ def encode(self, image):
         std = torch.exp(0.5 * logvar)
         return mean + std * torch.randn_like(mean)

-
-class DiagonalGaussianDistribution:
-    def __init__(self, parameters, deterministic=False, chunk_dim: int = 1):
-        self.parameters = parameters
-        self.mean, self.logvar = torch.chunk(parameters, 2, dim=chunk_dim)
-        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
-        self.deterministic = deterministic
-        self.std = torch.exp(0.5 * self.logvar)
-        self.var = torch.exp(self.logvar)
-        if self.deterministic:
-            self.var = self.std = torch.zeros_like(self.mean).to(
-                device=self.parameters.device
-            )
-
-    def sample(self):
-        x = self.mean + self.std * torch.randn(self.mean.shape).to(
-            device=self.parameters.device
-        )
-        return x
-
-    def kl(self, other=None):
-        if self.deterministic:
-            return torch.Tensor([0.0])
-        else:
-            if other is None:
-                return 0.5 * torch.sum(
-                    torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
-                    dim=list(range(1, self.mean.ndim)),
-                )
-            else:
-                return 0.5 * torch.sum(
-                    torch.pow(self.mean - other.mean, 2) / other.var
-                    + self.var / other.var
-                    - 1.0
-                    - self.logvar
-                    + other.logvar,
-                    dim=list(range(1, self.mean.ndim)),
-                )
-
-    def nll(self, sample, dims=[1, 2, 3]):
-        if self.deterministic:
-            return torch.Tensor([0.0])
-        logtwopi = np.log(2.0 * np.pi)
-        return 0.5 * torch.sum(
-            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
-            dim=dims,
-        )
-
-    def mode(self):
-        return self.mean
-
-
-class DiagonalGaussianRegularizer(nn.Module):
-    def __init__(self, sample: bool = True, chunk_dim: int = 1):
-        super().__init__()
-        self.sample = sample
-        self.chunk_dim = chunk_dim
-
-    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
-        log = dict()
-        posterior = DiagonalGaussianDistribution(z, chunk_dim=self.chunk_dim)
-        if self.sample:
-            z = posterior.sample()
-        else:
-            z = posterior.mode()
-        kl_loss = posterior.kl()
-        kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
-        log["kl_loss"] = kl_loss
-        return z, log
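The two classes deleted above implemented reparameterized sampling and the closed-form KL of a diagonal Gaussian against N(0, I), i.e. KL = 0.5 * sum(mu^2 + sigma^2 - 1 - log sigma^2) over the non-batch dims, which `DiagonalGaussianRegularizer` then averaged over the batch. A small sketch cross-checking that formula against torch.distributions (all names local to this example):

    import torch
    from torch.distributions import Normal, kl_divergence

    mean = torch.randn(4, 8)
    logvar = torch.randn(4, 8).clamp(-30.0, 20.0)  # same clamp the removed class applied
    var, std = logvar.exp(), (0.5 * logvar).exp()

    # Closed-form KL(N(mean, var) || N(0, I)), as the removed kl() computed it
    # for other=None (dim=list(range(1, ndim)) reduces to dim=1 here).
    kl_manual = 0.5 * torch.sum(mean.pow(2) + var - 1.0 - logvar, dim=1)

    # Reference: sum the per-dimension KLs from torch.distributions.
    kl_ref = kl_divergence(
        Normal(mean, std), Normal(torch.zeros_like(mean), torch.ones_like(std))
    ).sum(dim=1)

    print(torch.allclose(kl_manual, kl_ref, atol=1e-5))  # True

    # The regularizer's batch-mean reduction:
    kl_loss = torch.sum(kl_manual) / kl_manual.shape[0]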