@@ -135,7 +135,7 @@ class AAEModel(keras.models.Model):
     def __init__(self,molecule_shape,latent_dim=2,
                  enc_layers=2,enc_seed=64,
                  disc_layers=2,disc_seed=64,
-                 prior=tfp.distributions.Normal(loc=0,scale=1),hp=_default_hp,with_density=False):
+                 prior=tfp.distributions.Normal(loc=0,scale=1),hp=_default_hp,with_density=False,with_cv1_bias=False):
         super().__init__()
 
         self.hp = hp
@@ -150,6 +150,9 @@ def __init__(self,molecule_shape,latent_dim=2,
         self.get_prior = _PriorImage(latent_dim,prior)
 
         self.with_density = with_density
+        self.with_cv1_bias = with_cv1_bias
+
+        assert not (with_density and with_cv1_bias)
 
         self.enc_seed = enc_seed
         self.disc_seed = disc_seed
@@ -285,6 +288,8 @@ def train_step(self,in_batch):
         cheat_grads = ctape.gradient(cheat_loss,self.enc.trainable_weights)
         self.optimizer.apply_gradients(zip(cheat_grads,self.enc.trainable_weights))
 
+        dens_loss = 42.
+
         # FOLLOW DENSITIES
         if self.with_density:
             with tf.GradientTape() as detape:
@@ -298,11 +303,15 @@ def train_step(self,in_batch):
             dens_grads = detape.gradient(dens_loss,self.enc.trainable_weights)
             self.optimizer.apply_gradients(zip(dens_grads,self.enc.trainable_weights))
 
-        else:
-            dens_loss = 42.
+        # BIAS CV1
+        if self.with_cv1_bias:
+            with tf.GradientTape() as btape:
+                lows = self.enc(batch)
+                dens_loss = self.dens_loss_fn(in_batch[1],lows[:,0])
 
+            bias_grads = btape.gradient(dens_loss,self.enc.trainable_weights)
+            self.optimizer.apply_gradients(zip(bias_grads,self.enc.trainable_weights))
 
-
 
         return {
             'AE loss min': tf.reduce_min(ae_multiloss),
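For context, a minimal usage sketch of the new with_cv1_bias mode. This is an assumption built from the diff alone: the molecule_shape value, the batch size, and the reading that the dataset yields (conformations, CV1 targets) pairs — inferred from the in_batch[1] indexing in train_step — are hypothetical, not confirmed by this commit.

# Hypothetical sketch, not part of the commit. Assumes the dataset yields
# (conformations, cv1_targets) pairs so that in_batch[1] carries the CV1
# targets consumed by self.dens_loss_fn in train_step above.
import numpy as np
import tensorflow as tf

model = AAEModel(molecule_shape=(30,3),   # assumed shape, not from the diff
                 with_cv1_bias=True)      # new flag; incompatible with with_density
model.compile(optimizer=tf.keras.optimizers.Adam())

conf = np.random.random((256,30,3)).astype(np.float32)  # placeholder conformations
cv1 = np.random.random((256,)).astype(np.float32)       # placeholder CV1 targets
ds = tf.data.Dataset.from_tensor_slices((conf,cv1)).batch(32)
model.fit(ds,epochs=1)

Since the density-following and CV1-biasing branches now share the dens_loss reporting path (initialized to the 42. sentinel before either branch runs), the new assert in __init__ keeps the two modes from being enabled at once.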