-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathnet.py
57 lines (46 loc) · 2.13 KB
/
net.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import tensorflow as tf
class GAN(tf.keras.Model):
    """Generative adversarial network with a custom ``train_step``.

    Wraps a separate discriminator model and generator model and, for every
    batch seen by ``fit``, performs one discriminator update followed by one
    generator update.

    Label convention used by ``train_step``: generated images are labelled 1
    and real images are labelled 0.
    """

    def __init__(self, discriminator, generator, latent_dim):
        """Store the two sub-models and the latent dimensionality.

        Args:
            discriminator: Model that scores images; its output is compared
                by ``loss_fn`` against label tensors of shape (batch, 1).
            generator: Model called on latent tensors of shape
                (batch, 1, 1, latent_dim); its output is concatenated with
                real images along axis 0, so it presumably must match their
                per-sample shape (TODO confirm against the generator used).
            latent_dim: Size of the last axis of the sampled latent vectors.
        """
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim

    def compile(self, discriminator_optimizer, generator_optimizer, loss_fn):
        """Configure the optimizers and the loss used by ``train_step``.

        Args:
            discriminator_optimizer: Optimizer applied to the discriminator's
                trainable weights.
            generator_optimizer: Optimizer applied to the generator's
                trainable weights.
            loss_fn: Callable ``loss_fn(y_true, y_pred)`` used for both the
                discriminator loss and the generator loss.
        """
        super().compile()
        self.discriminator_optimizer = discriminator_optimizer
        self.generator_optimizer = generator_optimizer
        self.loss_fn = loss_fn

    def _sample_latent(self, batch_size):
        """Draw standard-normal latent vectors of shape (batch, 1, 1, latent_dim)."""
        return tf.random.normal(shape=(batch_size, 1, 1, self.latent_dim))

    def train_step(self, real_images):
        """Run one discriminator update followed by one generator update.

        Args:
            real_images: A batch of real images, or a tuple whose first
                element is that batch (as ``fit`` passes it when the dataset
                yields ``(images, labels)`` pairs).

        Returns:
            Dict with this batch's ``"d_loss"`` and ``"g_loss"``.
        """
        # fit() hands over the whole dataset element; keep only the images
        # when the dataset yields (images, labels) pairs.
        if isinstance(real_images, tuple):
            real_images = real_images[0]

        batch_size = tf.shape(real_images)[0]

        # ---- Discriminator step -------------------------------------------
        random_latent_vectors = self._sample_latent(batch_size)
        # Pass training=True explicitly (as the Keras custom train_step guide
        # does) so layers like BatchNormalization/Dropout run in training mode.
        generated_images = self.generator(random_latent_vectors, training=True)
        combined_images = tf.concat([generated_images, real_images], axis=0)
        # Generated -> 1, real -> 0, in the same order as the concat above.
        labels = tf.concat(
            [tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0
        )
        # Label-noise trick: add uniform noise in [0, 0.05). The noise is
        # non-negative, so "generated" targets end up slightly above 1.0.
        labels += 0.05 * tf.random.uniform(tf.shape(labels))

        with tf.GradientTape() as tape:
            predictions = self.discriminator(combined_images, training=True)
            d_loss = self.loss_fn(labels, predictions)
        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.discriminator_optimizer.apply_gradients(
            zip(grads, self.discriminator.trainable_weights)
        )

        # ---- Generator step -----------------------------------------------
        # Fresh latent samples, so the generator is not updated on exactly the
        # batch the discriminator was just trained against.
        random_latent_vectors = self._sample_latent(batch_size)
        # Label every generated image as "real" (0): the generator is rewarded
        # when the discriminator is fooled.
        misleading_labels = tf.zeros((batch_size, 1))
        with tf.GradientTape() as tape:
            predictions = self.discriminator(
                self.generator(random_latent_vectors, training=True),
                training=True,
            )
            g_loss = self.loss_fn(misleading_labels, predictions)
        # Only the generator's gradients are computed and applied, so the
        # discriminator's weights are not updated in this step.
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.generator_optimizer.apply_gradients(
            zip(grads, self.generator.trainable_weights)
        )
        return {"d_loss": d_loss, "g_loss": g_loss}