Skip to content

Commit 1fef972

Browse files
committed
Merge branch 'master' of github.com:ljocha/ASMSA
2 parents 4d52a55 + 97e0d0e commit 1fef972

File tree

1 file changed

+13
-4
lines changed

1 file changed

+13
-4
lines changed

src/asmsa/aae_model.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ class AAEModel(keras.models.Model):
135135
def __init__(self,molecule_shape,latent_dim=2,
136136
enc_layers=2,enc_seed=64,
137137
disc_layers=2,disc_seed=64,
138-
prior=tfp.distributions.Normal(loc=0, scale=1),hp=_default_hp,with_density=False):
138+
prior=tfp.distributions.Normal(loc=0, scale=1),hp=_default_hp,with_density=False,with_cv1_bias=False):
139139
super().__init__()
140140

141141
self.hp = hp
@@ -150,6 +150,9 @@ def __init__(self,molecule_shape,latent_dim=2,
150150
self.get_prior = _PriorImage(latent_dim,prior)
151151

152152
self.with_density = with_density
153+
self.with_cv1_bias = with_cv1_bias
154+
155+
assert not (with_density and with_cv1_bias)
153156

154157
self.enc_seed = enc_seed
155158
self.disc_seed = disc_seed
@@ -285,6 +288,8 @@ def train_step(self,in_batch):
285288
cheat_grads = ctape.gradient(cheat_loss,self.enc.trainable_weights)
286289
self.optimizer.apply_gradients(zip(cheat_grads,self.enc.trainable_weights))
287290

291+
dens_loss = 42.
292+
288293
# FOLLOW DENSITIES
289294
if self.with_density:
290295
with tf.GradientTape() as detape:
@@ -298,11 +303,15 @@ def train_step(self,in_batch):
298303
dens_grads = detape.gradient(dens_loss,self.enc.trainable_weights)
299304
self.optimizer.apply_gradients(zip(dens_grads,self.enc.trainable_weights))
300305

301-
else:
302-
dens_loss = 42.
306+
# BIAS CV1
307+
if self.with_cv1_bias:
308+
with tf.GradientTape() as btape:
309+
lows = self.enc(batch)
310+
dens_loss = self.dens_loss_fn(in_batch[1],lows[:,0])
303311

312+
bias_grads = btape.gradient(dens_loss,self.enc.trainable_weights)
313+
self.optimizer.apply_gradients(zip(bias_grads,self.enc.trainable_weights))
304314

305-
306315

307316
return {
308317
'AE loss min' : tf.reduce_min(ae_multiloss),

0 commit comments

Comments
 (0)