-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel.py
147 lines (116 loc) · 5.33 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.nn.modules.activation import ReLU
from torch.nn.modules.batchnorm import BatchNorm1d
from torch.nn.modules.conv import Conv2d, ConvTranspose2d
from torch.nn.modules.linear import Linear
from torchvision import datasets
'''
The generator below follows the architecture described in the DCGAN paper,
'Unsupervised Representation Learning with Deep Convolutional Generative
Adversarial Networks'.
'''
class conv_block(nn.Module):
    """One convolution stage shared by the generator and the discriminator.

    * ``type='generator'``: ``ConvTranspose2d`` followed by
      ``BatchNorm2d`` + ``ReLU``, or by ``Tanh`` alone when ``last`` is true.
    * ``type='discriminator'``: ``Conv2d`` followed by
      ``BatchNorm2d`` + ``LeakyReLU(0.2)``, or by ``Sigmoid`` alone when
      ``last`` is true.

    The convolution has no bias and its weight is initialised from
    ``N(0, 0.02)``.

    Args:
        in_channel: Number of input channels.
        out_channel: Number of output channels.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Convolution padding.
        type: Either ``'generator'`` or ``'discriminator'``. The default list
            only enumerates the accepted values; it is not itself a valid
            choice, so callers must always pass one of the two strings.
        last: Whether this is the final stage of the network (no batch norm,
            output activation instead of the hidden one).

    Raises:
        ValueError: If ``type`` is not one of the two accepted strings.
    """

    def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0,
                 type=['generator', 'discriminator'], last=False):
        super().__init__()
        # NOTE: the parameter name ``type`` shadows the builtin inside this
        # method; it is kept because callers pass it by keyword.
        if type not in ('generator', 'discriminator'):
            raise ValueError(
                f"type must be 'generator' or 'discriminator', got {type!r}"
            )

        # The layer list is deliberately a local variable: storing it as
        # ``self.modules`` would shadow ``nn.Module.modules()`` and break any
        # caller that iterates over the network's modules.
        layers = []
        if type == 'generator':
            self.conv = nn.ConvTranspose2d(in_channel, out_channel, kernel_size,
                                           stride=stride, padding=padding, bias=False)
            layers.append(self.conv)
            if last:
                self.act = nn.Tanh()
                layers.append(self.act)
            else:
                self.bn = nn.BatchNorm2d(out_channel)
                self.act = nn.ReLU()
                layers.extend([self.bn, self.act])
        else:
            self.conv = nn.Conv2d(in_channel, out_channel, kernel_size=kernel_size,
                                  stride=stride, padding=padding, bias=False)
            layers.append(self.conv)
            if last:
                self.act = nn.Sigmoid()
                layers.append(self.act)
            else:
                self.bn = nn.BatchNorm2d(out_channel)
                self.act = nn.LeakyReLU(0.2, inplace=True)
                layers.extend([self.bn, self.act])

        nn.init.normal_(self.conv.weight, mean=0, std=0.02)
        # The layers stay reachable both as attributes (conv / bn / act) and
        # through ``block`` so that existing state_dict keys keep working.
        self.block = nn.Sequential(*layers)

    def forward(self, input):
        """Apply the convolution, optional batch norm and activation to ``input``."""
        return self.block(input)
class generator_net(nn.Module):
    """Generator built from five transposed-convolution stages.

    Channel widths go ``in_channels -> 16f -> 8f -> 4f -> 2f -> out_channels``
    where ``f`` is ``feature_maps``. For an input of shape
    ``(N, in_channels, 1, 1)`` the spatial size grows 4 -> 8 -> 16 -> 32 -> 64,
    so the output is ``(N, out_channels, 64, 64)``.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels of the generated image.
        feature_maps: Base channel width ``f`` of the network.
    """

    def __init__(self, in_channels, out_channels, feature_maps=64):
        super().__init__()
        f = feature_maps
        # Every stage after the first doubles the spatial resolution.
        doubling = dict(kernel_size=4, stride=2, padding=1, type='generator')

        # (N, 16f, 4, 4)
        self.blk1 = conv_block(in_channels, f * 16, kernel_size=4,
                               stride=1, padding=0, type='generator')
        # (N, 8f, 8, 8)
        self.blk2 = conv_block(f * 16, f * 8, **doubling)
        # (N, 4f, 16, 16)
        self.blk3 = conv_block(f * 8, f * 4, **doubling)
        # (N, 2f, 32, 32)
        self.blk4 = conv_block(f * 4, f * 2, **doubling)
        # (N, out_channels, 64, 64)
        self.blk5 = conv_block(f * 2, out_channels, last=True, **doubling)

        stages = [self.blk1, self.blk2, self.blk3, self.blk4, self.blk5]
        self.net = nn.Sequential(*stages)

    def forward(self, input):
        """Run ``input`` through all five stages and return the result."""
        return self.net(input)
class discriminator_net(nn.Module):
    """Discriminator built from three strided convolutions and an 8x8 head.

    Channel widths go ``in_channel -> f -> 2f -> 4f -> 1`` where ``f`` is
    ``feature_maps``. For a 64x64 input the spatial size shrinks
    32 -> 16 -> 8, and the final 8x8 convolution with a sigmoid collapses it
    to a single value per sample, i.e. an output of shape ``(N, 1, 1, 1)``.

    Args:
        in_channel: Number of channels of the input image.
        feature_maps: Base channel width ``f`` of the network.
    """

    def __init__(self, in_channel, feature_maps):
        super().__init__()
        f = feature_maps
        # Each of the first three stages halves the spatial resolution.
        halving = dict(kernel_size=4, stride=2, padding=1, type='discriminator')

        # (N, f, 32, 32)
        self.blk1 = conv_block(in_channel, f, **halving)
        # (N, 2f, 16, 16)
        self.blk2 = conv_block(f, f * 2, **halving)
        # (N, 4f, 8, 8)
        self.blk3 = conv_block(f * 2, f * 4, **halving)
        # Single 8x8 convolution as the output head. An earlier variant used
        # two stride-1 4x4 stages (4f -> 8f -> 1) here instead.
        self.blk4 = conv_block(f * 4, 1, kernel_size=8, stride=1, padding=0,
                               type='discriminator', last=True)

        stages = [self.blk1, self.blk2, self.blk3, self.blk4]
        self.net = nn.Sequential(*stages)

    def forward(self, input):
        """Run ``input`` through all four stages and return the result."""
        return self.net(input)
class autoencoder(nn.Module):
    """Fully connected network mapping an image to a feature vector.

    The image is flattened to ``height * width`` values and passed through
    ``Linear(256) -> Dropout(0.1) -> ReLU -> Linear(128) -> Dropout(0.1) ->
    ReLU -> Linear(feature_vector_length) -> Sigmoid``.

    Args:
        height: Expected image height.
        width: Expected image width.
        feature_vector_length: Size of the produced feature vector.
    """

    def __init__(self, height, width, feature_vector_length):
        super().__init__()
        self.height = height
        self.width = width
        self.net = nn.Sequential(
            nn.Linear(height * width, 256),
            nn.Dropout(0.1),
            nn.ReLU(inplace=True),
            nn.Linear(256, 128),
            nn.Dropout(0.1),
            nn.ReLU(inplace=True),
            nn.Linear(128, feature_vector_length),
            nn.Sigmoid()
        )

    def forward(self, input):
        """Encode one image or a batch of images.

        Args:
            input: Tensor of shape ``(height, width)`` or, with any number of
                leading batch dimensions, ``(..., height, width)``.

        Returns:
            Tensor of shape ``(feature_vector_length,)`` for a single image,
            or ``(..., feature_vector_length)`` for batched input.

        Raises:
            ValueError: If the last two dimensions of ``input`` are not
                ``(height, width)``.
        """
        # Explicit check instead of ``assert`` so the validation still runs
        # when Python is started with -O.
        if tuple(input.shape[-2:]) != (self.height, self.width):
            raise ValueError('Wrong input image size')
        # Flatten only the image dimensions; leading batch dims are preserved.
        # For a plain (height, width) input this equals ``input.flatten()``.
        flat = input.flatten(start_dim=-2)
        return self.net(flat)