Commit 7904cd1

added code
1 parent e1a53a4 commit 7904cd1

File tree

1 file changed: +117 -0 lines changed

GAN

Lines changed: 117 additions & 0 deletions
# Libraries
library(keras)
library(EBImage)

# MNIST data
mnist <- dataset_mnist()
c(c(trainx, trainy), c(testx, testy)) %<-% mnist
trainx <- trainx[trainy == 5,,]          # keep only images of the digit 5
par(mfrow = c(8, 8), mar = rep(0, 4))
for (i in 1:64) plot(as.raster(trainx[i,,], max = 255))
trainx <- array_reshape(trainx, c(nrow(trainx), 28, 28, 1))
trainx <- trainx/255

# Generator network
h <- 28; w <- 28; c <- 1; l <- 28        # image height, width, channels; latent dimension
gi <- layer_input(shape = l)
go <- gi %>% layer_dense(units = 32 * 14 * 14) %>%
  layer_activation_leaky_relu() %>%
  layer_reshape(target_shape = c(14, 14, 32)) %>%
  layer_conv_2d(filters = 32,
                kernel_size = 5,
                padding = "same") %>%
  layer_activation_leaky_relu() %>%
  layer_conv_2d_transpose(filters = 32,
                          kernel_size = 4,
                          strides = 2,
                          padding = "same") %>%
  layer_activation_leaky_relu() %>%
  layer_conv_2d(filters = 1,
                kernel_size = 5,
                activation = "tanh",
                padding = "same")
g <- keras_model(gi, go)
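
# Sanity check (assumption, not part of the original commit): the generator
# should map a length-28 noise vector to a single 28 x 28 x 1 image.
summary(g)
dim(g %>% predict(matrix(rnorm(l), nrow = 1, ncol = l)))   # expect: 1 28 28 1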

# Discriminator
di <- layer_input(shape = c(h, w, c))
do <- di %>%
  layer_conv_2d(filters = 64,
                kernel_size = 4) %>%
  layer_activation_leaky_relu() %>%
  layer_flatten() %>%
  layer_dropout(rate = 0.3) %>%
  layer_dense(units = 1,
              activation = "sigmoid")
d <- keras_model(di, do)

# Compile
d %>% compile(optimizer = 'rmsprop', loss = 'binary_crossentropy')
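
# Quick check (assumption, not from the commit): the discriminator should map a
# batch of 28 x 28 x 1 images to one sigmoid probability per image.
summary(d)
dim(d %>% predict(array(runif(2 * 28 * 28), dim = c(2, 28, 28, 1))))   # expect: 2 1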

# Freeze weights and compile
freeze_weights(d)
gani <- layer_input(shape = l)
gano <- gani %>% g %>% d
gan <- keras_model(gani, gano)
gan %>% compile(optimizer = 'rmsprop',
                loss = "binary_crossentropy")
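
# Check (assumption, not in the original commit): with freeze_weights(d) applied,
# the stacked gan should update only the generator's weights during training.
length(d$trainable_weights)     # expected to be 0 while frozen
length(gan$trainable_weights)   # generator weights only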

# Training
b <- 50                                   # batch size (real and fake images per step)
setwd("~/Desktop/")
dir <- "gan_img"
dir.create(dir)
start <- 1; dloss <- NULL; gloss <- NULL

for (i in 1:100) {
  # 1. Generate 50 fake images from random noise
  noise <- matrix(rnorm(b*l), nrow = b, ncol = l)
  fake <- g %>% predict(noise)

  # 2. Combine real & fake images; fakes get labels near 1, reals near 0
  #    (uniform noise in the labels gives soft targets)
  stop <- start + b - 1
  real <- trainx[start:stop,,,]
  real <- array_reshape(real, c(nrow(real), 28, 28, 1))
  rows <- nrow(real)
  both <- array(0, dim = c(rows*2, dim(real)[-1]))
  both[1:rows,,,] <- fake
  both[(rows+1):(rows*2),,,] <- real
  labels <- rbind(matrix(runif(b, 0.9, 1), nrow = b, ncol = 1),
                  matrix(runif(b, 0, 0.1), nrow = b, ncol = 1))
  start <- start + b

  # 3. Train the discriminator on the mixed batch
  dloss[i] <- d %>% train_on_batch(both, labels)

  # 4. Train the generator through the gan (discriminator frozen),
  #    labelling fresh fakes as if they were real
  fakeAsReal <- array(runif(b, 0, 0.1), dim = c(b, 1))
  gloss[i] <- gan %>% train_on_batch(noise, fakeAsReal)

  # 5. Save the first fake image of each batch
  f <- fake[1,,,]
  dim(f) <- c(28, 28, 1)
  image_array_save(f, path = file.path(dir, paste0("f", i, ".png")))
}

# Plot loss
x <- 1:100
plot(x, dloss, col = 'red', type = 'l',
     ylim = c(0, 3),
     xlab = 'Iterations',
     ylab = 'Loss')
lines(x, gloss, col = 'black', type = 'l')
legend('topright',
       legend = c("Discriminator Loss", "GAN Loss"),
       col = c("red", 'black'), lty = 1, cex = 1)

# 100 fake images
setwd("~/Desktop/gan_img")
temp <- list.files(pattern = "\\.png$")
mypic <- list()
for (i in 1:length(temp)) {mypic[[i]] <- readImage(temp[[i]])}
par(mfrow = c(10, 10), mar = rep(0, 4))
for (i in 1:length(temp)) plot(mypic[[i]])
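
# Example (assumption, not part of the original commit): instead of re-reading
# the saved PNGs, fake digits can also be sampled directly from the trained
# generator; tanh outputs are clamped to [0, 1] before plotting.
noise <- matrix(rnorm(b * l), nrow = b, ncol = l)
fake <- g %>% predict(noise)
par(mfrow = c(5, 10), mar = rep(0, 4))
for (i in 1:b) plot(as.raster(pmin(pmax(fake[i,,,1], 0), 1)))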
