MSG_CGAN_Generator.py
import torch
class Generator(torch.nn.Module):
"""
    This Generator is made up of blocks of transposed convolutional layers, each followed by
    BatchNorm and LeakyReLU. It returns RGB outputs at 4x4, 8x8, 16x16 and 32x32 resolution
    (smallest first), in the style of a multi-scale (MSG) conditional GAN.
"""
    def __init__(self, Noise_size, Label_size, Channel_size, Picture_size):
        """
        Noise_size:   length of the input noise vector
        Label_size:   number of classes used for one-hot label conditioning
        Channel_size: number of output image channels (accepted but not used; the network always outputs 3 channels)
        Picture_size: spatial size of the output image (accepted but not used; the network always outputs 32x32)
        """
        super().__init__()
        # Project the concatenated noise + one-hot label vector to a 512 x 2 x 2 feature map.
        self.dense = torch.nn.Sequential(
            torch.nn.Linear(in_features=Noise_size + Label_size, out_features=512 * 2 * 2),
            torch.nn.BatchNorm1d(512 * 2 * 2, momentum=0.9),
            torch.nn.LeakyReLU(0.1)
        )
        # Upsample 2x2 -> 4x4
        self.conv1 = torch.nn.Sequential(
            torch.nn.ConvTranspose2d(
                in_channels=512, out_channels=256, kernel_size=5, stride=2, padding=2, output_padding=1),
            torch.nn.BatchNorm2d(256, momentum=0.9),
            torch.nn.LeakyReLU(0.1)
        )
        # 1x1 convolution mapping the 4x4 feature map to a 3-channel RGB output
        self.out1 = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=256, out_channels=3, kernel_size=1, stride=1)
        )
        # Upsample 4x4 -> 8x8
        self.conv2 = torch.nn.Sequential(
            torch.nn.ConvTranspose2d(
                in_channels=256, out_channels=128, kernel_size=5, stride=2, padding=2, output_padding=1),
            torch.nn.BatchNorm2d(128, momentum=0.9),
            torch.nn.LeakyReLU(0.1)
        )
        # 1x1 convolution mapping the 8x8 feature map to a 3-channel RGB output
        self.out2 = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=128, out_channels=3, kernel_size=1, stride=1)
        )
        # Upsample 8x8 -> 16x16
        self.conv3 = torch.nn.Sequential(
            torch.nn.ConvTranspose2d(
                in_channels=128, out_channels=64, kernel_size=5, stride=2, padding=2, output_padding=1),
            torch.nn.BatchNorm2d(64, momentum=0.9),
            torch.nn.LeakyReLU(0.1)
        )
        # 1x1 convolution mapping the 16x16 feature map to a 3-channel RGB output
        self.out3 = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=64, out_channels=3, kernel_size=1, stride=1)
        )
        # Upsample 16x16 -> 32x32 and reduce to 3 channels for the final output
        self.conv4 = torch.nn.Sequential(
            torch.nn.ConvTranspose2d(
                in_channels=64, out_channels=3, kernel_size=5, stride=2, padding=2, output_padding=1),
        )
self.Label_size = Label_size
    def forward(self, noise, labels):
        """
        noise has shape (BATCH_SIZE, Noise_size)
        labels has shape (BATCH_SIZE,) and contains integer class indices
        """
        labels = torch.nn.functional.one_hot(labels, num_classes=self.Label_size)  # shape (BATCH_SIZE, Label_size)
        temp = torch.cat((noise, labels.float()), 1)  # shape (BATCH_SIZE, Noise_size + Label_size)
        temp = self.dense(temp)                       # shape (BATCH_SIZE, 2048)
        temp = torch.reshape(temp, (-1, 512, 2, 2))   # shape (BATCH_SIZE, 512, 2, 2)
        temp = self.conv1(temp)                       # shape (BATCH_SIZE, 256, 4, 4)
        out1 = self.out1(temp)
        temp = self.conv2(temp)                       # shape (BATCH_SIZE, 128, 8, 8)
        out2 = self.out2(temp)
        temp = self.conv3(temp)                       # shape (BATCH_SIZE, 64, 16, 16)
        out3 = self.out3(temp)
        out4 = self.conv4(temp)                       # shape (BATCH_SIZE, 3, 32, 32)
        # Rescale the tanh outputs from [-1, 1] to [0, 1]
        out1 = 0.5 * (1 + torch.tanh(out1))
        out2 = 0.5 * (1 + torch.tanh(out2))
        out3 = 0.5 * (1 + torch.tanh(out3))
        out4 = 0.5 * (1 + torch.tanh(out4))
        return out1, out2, out3, out4  # smallest resolution first
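

# The smoke test below is a minimal sketch, not part of the original file. It assumes
# NOISE_SIZE = 100 and LABEL_SIZE = 10 (e.g. a 10-class dataset of 32x32 RGB images
# such as CIFAR-10); Channel_size and Picture_size are passed for completeness even
# though this implementation does not use them.
if __name__ == "__main__":
    NOISE_SIZE, LABEL_SIZE = 100, 10             # assumed hyperparameters
    generator = Generator(NOISE_SIZE, LABEL_SIZE, Channel_size=3, Picture_size=32)
    generator.eval()                             # BatchNorm in eval mode so a small batch works

    noise = torch.randn(4, NOISE_SIZE)           # batch of 4 latent vectors
    labels = torch.randint(0, LABEL_SIZE, (4,))  # random integer class indices
    with torch.no_grad():
        out1, out2, out3, out4 = generator(noise, labels)
    # Expected shapes: (4, 3, 4, 4), (4, 3, 8, 8), (4, 3, 16, 16), (4, 3, 32, 32)
    print(out1.shape, out2.shape, out3.shape, out4.shape)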