# PyTorch Implementation of VAEGAN

## Usage

```bash
$ python3 main.py --ndata 'mnist' --epochs 50
```

> **_NOTE:_** On a Colab notebook, use the following commands:

```python
!git clone link-to-repo
%run main.py --ndata 'cifar10' --epochs 50
```

## Contributed by:

* [Rishabh Dugaye](https://github.com/rishabhd786)

## References

* **Title**: Autoencoding beyond pixels using a learned similarity metric
* **Authors**: Anders Boesen Lindbo Larsen, Søren Kaae Sønderby, Hugo Larochelle, Ole Winther
* **Link**: https://arxiv.org/pdf/1512.09300.pdf
* **Tags**: Neural Network, Generative Networks, GANs
* **Year**: 2016

# Summary

## Introduction

* The paper combines VAEs and GANs into an unsupervised generative model that simultaneously learns to encode, generate, and compare dataset samples.

* It shows that generative models trained with learned similarity measures produce better image samples than models trained with element-wise error measures.

* It demonstrates that unsupervised training results in a latent image representation with disentangled factors of variation (Bengio et al., 2013). This is illustrated in experiments on a dataset of face images labelled with visual attribute vectors, where it is shown that simple arithmetic applied in the learned latent space produces images that reflect changes in these attributes.

## Variational Autoencoder

A VAE consists of two networks that encode a data sample x to a latent representation z and decode the latent representation back to data space, respectively:

![1](./assets/vae.png)

The VAE regularizes the encoder by imposing a prior over the latent distribution p(z). Typically z ∼ N(0, I) is chosen. The VAE loss is minus the sum of the expected log likelihood (the reconstruction error) and a prior regularization term:

![2](./assets/l_vae.png)

![3](./assets/vae1.png)

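As a concrete illustration, the two VAE ingredients above, the reparameterization trick and the closed-form KL prior term for z ∼ N(0, I), look roughly like this in PyTorch; a minimal sketch, not necessarily this repo's exact code:

```python
import torch

def reparameterize(mu, logvar):
    # z = mu + sigma * eps with eps ~ N(0, I); keeps sampling differentiable
    std = torch.exp(0.5 * logvar)
    return mu + torch.randn_like(std) * std

def prior_loss(mu, logvar):
    # Closed-form KL( N(mu, sigma^2) || N(0, I) ), summed over latent dims and batch
    return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
```
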
## Generative Adversarial Network

A GAN consists of two networks: the generator network Gen(z) maps latents z to data space, while the discriminator network assigns probability y = Dis(x) ∈ [0, 1] that x is an actual training sample and probability 1 − y that x is generated by our model through x = Gen(z) with z ∼ p(z). The GAN objective is to find the binary classifier that gives the best possible discrimination between true and generated data while simultaneously encouraging Gen to fit the true data distribution. We thus aim to maximize/minimize the binary cross-entropy with respect to Dis/Gen, with x being a training sample and z ∼ p(z).

![4](./assets/gan_l.png)

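In code, the two objectives reduce to binary cross-entropy on real and generated batches. A minimal sketch (assuming a discriminator `dis` with sigmoid output, as in the architecture summarized below; in practice `fake_x` is detached for the discriminator update):

```python
import torch
import torch.nn.functional as F

def gan_losses(dis, real_x, fake_x):
    real_pred, fake_pred = dis(real_x), dis(fake_x)
    # Dis is pushed towards 1 on data and 0 on generated samples...
    dis_loss = (F.binary_cross_entropy(real_pred, torch.ones_like(real_pred))
                + F.binary_cross_entropy(fake_pred, torch.zeros_like(fake_pred)))
    # ...while Gen is pushed to make Dis output 1 on its samples
    gen_loss = F.binary_cross_entropy(fake_pred, torch.ones_like(fake_pred))
    return dis_loss, gen_loss
```
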
## Beyond element-wise reconstruction error with VAE/GAN

Since element-wise reconstruction errors are not adequate for images and other signals with invariances, we propose replacing the VAE reconstruction (expected log likelihood) error term with a reconstruction error expressed in the GAN discriminator. To achieve this, let Dis_l(x) denote the hidden representation of the l-th layer of the discriminator. We introduce a Gaussian observation model for Dis_l(x) with mean Dis_l(x̃) and identity covariance. We train our combined model with the triple criterion:

![4](./assets/tripl.png)

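With identity covariance, the negative Gaussian log likelihood on Dis_l features reduces to a squared error between feature maps. A sketch, assuming a hypothetical `dis.features(x)` helper that returns the l-th hidden layer activations:

```python
import torch.nn.functional as F

def llike_loss(dis, x, x_tilde):
    # -log N( Dis_l(x) | Dis_l(x_tilde), I ) up to an additive constant:
    # 0.5 * || Dis_l(x) - Dis_l(x_tilde) ||^2
    return 0.5 * F.mse_loss(dis.features(x_tilde), dis.features(x),
                            reduction='sum')
```
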
Notably, we optimize the VAE with respect to L_GAN, which we regard as a style error, in addition to the reconstruction error, which can be interpreted as a content error using the terminology from Gatys et al. (2015). Moreover, since both Dec and Gen map from z to x, we share the parameters between the two.

![4](./assets/model.png)

### Algorithm used for training

![4](./assets/algo.png)

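Putting the pieces together, one iteration of the algorithm can be sketched as three updates, reusing the loss helpers sketched above and treating L_gan as the discriminator's BCE loss over real, reconstructed, and prior-sampled images. The `enc`, `dec`, `dis` modules, their optimizers, and the default value of the content/style weight γ are assumptions, not this repo's exact code:

```python
import torch
import torch.nn.functional as F

def train_step(x, enc, dec, dis, opt_enc, opt_dec, opt_dis,
               gamma=1e-6, z_dim=128):   # gamma value is an assumption
    bce = F.binary_cross_entropy

    def losses():
        mu, logvar = enc(x)
        x_tilde = dec(reparameterize(mu, logvar))                   # reconstruction
        x_p = dec(torch.randn(x.size(0), z_dim, device=x.device))  # sample from prior
        y_real, y_tilde, y_p = dis(x), dis(x_tilde), dis(x_p)
        l_gan = (bce(y_real, torch.ones_like(y_real))
                 + bce(y_tilde, torch.zeros_like(y_tilde))
                 + bce(y_p, torch.zeros_like(y_p)))
        return prior_loss(mu, logvar), llike_loss(dis, x, x_tilde), l_gan

    # Encoder: minimize L_prior + L_llike
    opt_enc.zero_grad()
    l_prior, l_llike, _ = losses()
    (l_prior + l_llike).backward()
    opt_enc.step()

    # Decoder: minimize gamma * L_llike - L_gan (reconstruct well AND fool Dis)
    opt_dec.zero_grad()
    _, l_llike, l_gan = losses()
    (gamma * l_llike - l_gan).backward()
    opt_dec.step()

    # Discriminator: minimize L_gan (separate real from x_tilde and x_p)
    opt_dis.zero_grad()
    _, _, l_gan = losses()
    l_gan.backward()
    opt_dis.step()
```
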
## Implementation and Model Architecture:

For all our experiments, we use convolutional architectures and use backward convolution (a.k.a. fractional striding) with stride 2 to upscale images in Dec. Backward convolution is achieved by flipping the convolution direction such that striding causes upsampling. Our models are trained with RMSProp using a learning rate of 0.0003 and a batch size of 64.

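In PyTorch, backward convolution corresponds to `nn.ConvTranspose2d`, and the optimizer setup is one `RMSprop` per network. A small sketch of both (the `enc`/`dec`/`dis` module variables are assumptions):

```python
import torch
import torch.nn as nn

# A stride-2 backward convolution doubles spatial resolution, e.g. 16x16 -> 32x32
up = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2,
                        padding=1, output_padding=1)
print(up(torch.randn(64, 256, 16, 16)).shape)  # torch.Size([64, 128, 32, 32])

# RMSProp with lr = 0.0003, one optimizer per network
opt_enc = torch.optim.RMSprop(enc.parameters(), lr=3e-4)
opt_dec = torch.optim.RMSprop(dec.parameters(), lr=3e-4)
opt_dis = torch.optim.RMSprop(dis.parameters(), lr=3e-4)
```
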
### Encoder

```
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [64, 64, 32, 32]           1,664
       BatchNorm2d-2           [64, 64, 32, 32]             128
         LeakyReLU-3           [64, 64, 32, 32]               0
            Conv2d-4          [64, 128, 16, 16]         204,928
       BatchNorm2d-5          [64, 128, 16, 16]             256
         LeakyReLU-6          [64, 128, 16, 16]               0
            Conv2d-7            [64, 256, 8, 8]         819,456
       BatchNorm2d-8            [64, 256, 8, 8]             512
         LeakyReLU-9            [64, 256, 8, 8]               0
           Linear-10                 [64, 2048]      33,556,480
      BatchNorm1d-11                 [64, 2048]           4,096
        LeakyReLU-12                 [64, 2048]               0
           Linear-13                  [64, 128]         262,272
           Linear-14                  [64, 128]         262,272
================================================================
Total params: 35,112,064
Trainable params: 35,112,064
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 1.00
Forward/backward pass size (MB): 171.12
Params size (MB): 133.94
Estimated Total Size (MB): 306.07
----------------------------------------------------------------
```
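
Read as a module, the summary above is consistent with a stack like the following (inferred from the parameter counts, which match 5×5 kernels and a 1-channel 64×64 input; the two final `Linear` heads produce `mu` and `logvar`). A sketch, not necessarily the repo's exact class:

```python
import torch.nn as nn

class Encoder(nn.Module):
    # 64x64 input -> three stride-2 convs -> FC 2048 -> two 128-d heads
    def __init__(self, in_channels=1, z_dim=128):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 64, 5, stride=2, padding=2),  # -> [64, 32, 32]
            nn.BatchNorm2d(64), nn.LeakyReLU(0.2),               # slope assumed
            nn.Conv2d(64, 128, 5, stride=2, padding=2),          # -> [128, 16, 16]
            nn.BatchNorm2d(128), nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, 5, stride=2, padding=2),         # -> [256, 8, 8]
            nn.BatchNorm2d(256), nn.LeakyReLU(0.2),
        )
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(256 * 8 * 8, 2048),
            nn.BatchNorm1d(2048), nn.LeakyReLU(0.2),
        )
        self.mu = nn.Linear(2048, z_dim)
        self.logvar = nn.Linear(2048, z_dim)

    def forward(self, x):
        h = self.fc(self.conv(x))
        return self.mu(h), self.logvar(h)
```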
### Discriminator

```
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [64, 32, 64, 64]             832
         LeakyReLU-2           [64, 32, 64, 64]               0
            Conv2d-3          [64, 128, 32, 32]         102,528
       BatchNorm2d-4          [64, 128, 32, 32]             256
         LeakyReLU-5          [64, 128, 32, 32]               0
            Conv2d-6          [64, 256, 16, 16]         819,456
       BatchNorm2d-7          [64, 256, 16, 16]             512
         LeakyReLU-8          [64, 256, 16, 16]               0
            Conv2d-9            [64, 256, 8, 8]       1,638,656
      BatchNorm2d-10            [64, 256, 8, 8]             512
        LeakyReLU-11            [64, 256, 8, 8]               0
           Linear-12                  [64, 512]       8,389,120
      BatchNorm1d-13                  [64, 512]           1,024
        LeakyReLU-14                  [64, 512]               0
           Linear-15                    [64, 1]             513
          Sigmoid-16                    [64, 1]               0
================================================================
Total params: 10,953,409
Trainable params: 10,953,409
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 1.00
Forward/backward pass size (MB): 440.75
Params size (MB): 41.78
Estimated Total Size (MB): 483.53
----------------------------------------------------------------
```
### Decoder

```
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
           Linear-16                [64, 16384]       2,113,536
      BatchNorm1d-17                [64, 16384]          32,768
        LeakyReLU-18                [64, 16384]               0
  ConvTranspose2d-19          [64, 256, 16, 16]       2,359,552
      BatchNorm2d-20          [64, 256, 16, 16]             512
        LeakyReLU-21          [64, 256, 16, 16]               0
  ConvTranspose2d-22          [64, 128, 32, 32]       1,179,776
      BatchNorm2d-23          [64, 128, 32, 32]             256
        LeakyReLU-24          [64, 128, 32, 32]               0
  ConvTranspose2d-25           [64, 32, 64, 64]         147,488
      BatchNorm2d-26           [64, 32, 64, 64]              64
        LeakyReLU-27           [64, 32, 64, 64]               0
  ConvTranspose2d-28            [64, 1, 64, 64]             801
             Tanh-29            [64, 1, 64, 64]               0
```

# Results

## Generated images after 25 epochs (MNIST)

![4](./assets/MNISTrec_noise_epoch_24.png.png)

## Reconstructed images after 25 epochs (MNIST)

![4](./assets/MNISTrec_epoch_24.png.png)

## Generated images after 30 epochs (CIFAR10)

![4](./assets/rec_epoch_28.png.png)

## Reconstructed images after 30 epochs (CIFAR10)

![4](./assets/rec_epoch_32.png.png)

## Plot of Prior Loss vs Iterations

![4](./assets/kl_divergence.png)

## Plot of GAN Loss vs Iterations

![4](./assets/gan_loss.png)

## Plot of Reconstruction Loss vs Iterations

![4](./assets/recon_loss.png)