|
| 1 | +# Libraries |
| 2 | +library(keras) |
| 3 | +library(EBImage) |
| 4 | + |
| 5 | +# MNIST data |
| 6 | +mnist <- dataset_mnist() |
| 7 | +c(c(trainx, trainy), c(testx, testy)) %<-% mnist |
| 8 | +trainx <- trainx[trainy == 5,,] |
| 9 | +par(mfrow = c(8, 8), mar = rep(0, 4)) |
| 10 | +for (i in 1:64) plot(as.raster(trainx[i,,], max = 255)) |
| 11 | +trainx <- array_reshape(trainx, c(nrow(trainx), 28, 28, 1)) |
| 12 | +trainx <- trainx/255 |
| 13 | + |
| 14 | +# Generator network |
| 15 | +h <- 28; w <- 28; c <- 1; l <- 28 |
| 16 | +gi <- layer_input(shape = l) |
| 17 | +go <- gi %>% layer_dense(units = 32 * 14 * 14) %>% |
| 18 | + layer_activation_leaky_relu() %>% |
| 19 | + layer_reshape(target_shape = c(14, 14, 32)) %>% |
| 20 | + layer_conv_2d(filters = 32, |
| 21 | + kernel_size = 5, |
| 22 | + padding = "same") %>% |
| 23 | + layer_activation_leaky_relu() %>% |
| 24 | + layer_conv_2d_transpose(filter = 32, |
| 25 | + kernel_size = 4, |
| 26 | + strides = 2, |
| 27 | + padding = "same") %>% |
| 28 | + layer_activation_leaky_relu() %>% |
| 29 | + layer_conv_2d(filters = 1, |
| 30 | + kernel_size = 5, |
| 31 | + activation = "tanh", |
| 32 | + padding = "same") |
| 33 | +g <- keras_model(gi, go) |
| 34 | + |
| 35 | +# Discriminator |
| 36 | +di <- layer_input(shape = c(h, w, c)) |
| 37 | +do <- di %>% |
| 38 | + layer_conv_2d(filters = 64, |
| 39 | + kernel_size = 4) %>% |
| 40 | + layer_activation_leaky_relu() %>% |
| 41 | + layer_flatten() %>% |
| 42 | + layer_dropout(rate = 0.3) %>% |
| 43 | + layer_dense(units = 1, |
| 44 | + activation = "sigmoid") |
| 45 | +d <- keras_model(di, do) |
| 46 | + |
| 47 | +# Compile |
| 48 | +d %>% compile(optimizer = 'rmsprop', loss = 'binary_crossentropy') |
| 49 | + |
| 50 | +# Freeze weights and compile |
| 51 | +freeze_weights(d) |
| 52 | +gani <- layer_input(shape = l) |
| 53 | +gano <- gani %>% g %>% d |
| 54 | +gan <- keras_model(gani, gano) |
| 55 | +gan %>% compile(optimizer = 'rmsprop', |
| 56 | + loss = "binary_crossentropy") |
| 57 | + |
| 58 | +# Training |
| 59 | +b <- 50 |
| 60 | +setwd("~/Desktop/") |
| 61 | +dir <- "gan_img" |
| 62 | +dir.create(dir) |
| 63 | +start <- 1; dloss <- NULL; gloss <- NULL |
| 64 | + |
| 65 | +#1. Generate 50 fake images |
| 66 | +for (i in 1:100) {noise <- matrix(rnorm(b*l), |
| 67 | + nrow = b, |
| 68 | + ncol = l) |
| 69 | +fake <- g %>% predict(noise) |
| 70 | + |
| 71 | +#2. Combine real & fake |
| 72 | +stop <- start + b - 1 |
| 73 | +real <- trainx[start:stop,,,] |
| 74 | +real <- array_reshape(real, c(nrow(real), 28, 28, 1 )) |
| 75 | +rows <- nrow(real) |
| 76 | +both <- array(0, dim = c(rows*2, dim(real)[-1])) |
| 77 | +both[1:rows,,,] <- fake |
| 78 | +both[(rows+1): (rows*2),,,] <- real |
| 79 | +labels <- rbind(matrix(runif(b, 0.9, 1), |
| 80 | + nrow = b, |
| 81 | + ncol = 1), |
| 82 | + matrix(runif(b, 0, 0.1), |
| 83 | + nrow = b, |
| 84 | + ncol = 1)) |
| 85 | +start <- start + b |
| 86 | + |
| 87 | +#3. Train discriminator |
| 88 | +dloss[i] <- d %>% train_on_batch(both, labels) |
| 89 | + |
| 90 | +#4. Train generator using gan |
| 91 | +fakeAsReal <- array(runif(b, 0, 0.1), dim = c(b, 1)) |
| 92 | +gloss[i] <- gan %>% train_on_batch(noise, fakeAsReal) |
| 93 | + |
| 94 | +#5. Save fake images |
| 95 | +f <- fake[1,,,] |
| 96 | +dim(f) <- c(28, 28, 1) |
| 97 | +image_array_save(f, path = file.path(dir, |
| 98 | + paste0("f", i, ".png")))} |
| 99 | + |
| 100 | +# Plot loss |
| 101 | +x <- 1:100 |
| 102 | +plot(x, dloss, col = 'red', type = 'l', |
| 103 | + ylim = c(0, 3), |
| 104 | + xlab = 'Iterations', |
| 105 | + ylab = 'Loss') |
| 106 | +lines(x, gloss, col = 'black', type = 'l') |
| 107 | +legend('topright', |
| 108 | + legend = c("Discriminator Loss", "GAN Loss"), |
| 109 | + col = c("red", 'black'), lty = 1:2, cex = 1) |
| 110 | + |
| 111 | +# 100 fake images |
| 112 | +setwd("~/Desktop/gan_img") |
| 113 | +temp = list.files(pattern = "*.png") |
| 114 | +mypic <- list() |
| 115 | +for (i in 1:length(temp)) {mypic[[i]] <- readImage(temp[[i]])} |
| 116 | +par(mfrow = c(10,10)) |
| 117 | +for (i in 1:length(temp)) plot(mypic[[i]]) |
0 commit comments