Skip to content

Commit 93c4486

Browse files
committed
fix format
1 parent 5647b04 commit 93c4486

File tree

1 file changed

+22
-13
lines changed

1 file changed

+22
-13
lines changed

_posts/2017-06-01-Generative-Adversial-Network-in-R.md

+22-13
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ require("readr")
3030
require("mxnet")
3131
```
3232

33-
The full demo is comprised of the two following scripts available on [GitHub](https://github.com/jeremiedb/gan_example):
33+
The full demo is comprised of the two following scripts available on [GitHub](https://github.com/dmlc/mxnet/tree/master/example/gan/blog_R_cgan):
3434

3535
- `CGAN_mnist_setup.R`: prepare data and define the model structure
3636
- `CGAN_train.R`: execute the training
@@ -41,7 +41,7 @@ The MNIST dataset is available on [Kaggle](https://www.kaggle.com/c/digit-recogn
4141

4242
```r
4343
train <- read_csv('data/train.csv')
44-
train<- data.matrix(train)
44+
train <- data.matrix(train)
4545

4646
train_data <- train[,-1]
4747
train_data <- t(train_data/255*2-1)
@@ -79,14 +79,18 @@ The training process of the discriminator is most obvious: the loss is simple a
7979

8080
```r
8181
### Train loop on fake
82-
mx.exec.update.arg.arrays(exec_D, arg.arrays = list(data=D_data_fake, digit=D_digit_fake, label=mx.nd.array(rep(0, batch_size))), match.name=TRUE)
82+
mx.exec.update.arg.arrays(exec_D, arg.arrays =
83+
list(data=D_data_fake, digit=D_digit_fake, label=mx.nd.array(rep(0, batch_size))),
84+
match.name=TRUE)
8385
mx.exec.forward(exec_D, is.train=T)
8486
mx.exec.backward(exec_D)
8587
update_args_D<- updater_D(weight = exec_D$ref.arg.arrays, grad = exec_D$ref.grad.arrays)
8688
mx.exec.update.arg.arrays(exec_D, update_args_D, skip.null=TRUE)
8789

8890
### Train loop on real
89-
mx.exec.update.arg.arrays(exec_D, arg.arrays = list(data=D_data_real, digit=D_digit_real, label=mx.nd.array(rep(1, batch_size))), match.name=TRUE)
91+
mx.exec.update.arg.arrays(exec_D, arg.arrays =
92+
list(data=D_data_real, digit=D_digit_real, label=mx.nd.array(rep(1, batch_size))),
93+
match.name=TRUE)
9094
mx.exec.forward(exec_D, is.train=T)
9195
mx.exec.backward(exec_D)
9296
update_args_D<- updater_D(weight = exec_D$ref.arg.arrays, grad = exec_D$ref.grad.arrays)
@@ -99,15 +103,20 @@ This requires to backpropagate the gradients up to the input data of the discrim
99103

100104
```r
101105
### Update Generator weights - use a seperate executor for writing data gradients
102-
exec_D_back<- mxnet:::mx.symbol.bind(symbol = D_sym, arg.arrays = exec_D$arg.arrays, aux.arrays = exec_D$aux.arrays, grad.reqs = rep("write", length(exec_D$arg.arrays)), ctx = devices)
103-
104-
mx.exec.update.arg.arrays(exec_D_back, arg.arrays = list(data=D_data_fake, digit=D_digit_fake, label=mx.nd.array(rep(1, batch_size))), match.name=TRUE)
106+
exec_D_back <- mxnet:::mx.symbol.bind(symbol = D_sym,
107+
arg.arrays = exec_D$arg.arrays,
108+
aux.arrays = exec_D$aux.arrays, grad.reqs = rep("write", length(exec_D$arg.arrays)),
109+
ctx = devices)
110+
111+
mx.exec.update.arg.arrays(exec_D_back, arg.arrays =
112+
list(data=D_data_fake, digit=D_digit_fake, label=mx.nd.array(rep(1, batch_size))),
113+
match.name=TRUE)
105114
mx.exec.forward(exec_D_back, is.train=T)
106115
mx.exec.backward(exec_D_back)
107-
D_grads<- exec_D_back$ref.grad.arrays$data
116+
D_grads <- exec_D_back$ref.grad.arrays$data
108117
mx.exec.backward(exec_G, out_grads=D_grads)
109118

110-
update_args_G<- updater_G(weight = exec_G$ref.arg.arrays, grad = exec_G$ref.grad.arrays)
119+
update_args_G <- updater_G(weight = exec_G$ref.arg.arrays, grad = exec_G$ref.grad.arrays)
111120
mx.exec.update.arg.arrays(exec_G, update_args_G, skip.null=TRUE)
112121
```
113122

@@ -148,11 +157,11 @@ Once the model is trained, synthetic images of the desired digit can be produced
148157
Here we will generate fake "9":
149158

150159
```r
151-
digit<- mx.nd.array(rep(9, times=batch_size))
152-
data<- mx.nd.one.hot(indices = digit, depth = 10)
153-
data<- mx.nd.reshape(data = data, shape = c(1,1,-1, batch_size))
160+
digit <- mx.nd.array(rep(9, times=batch_size))
161+
data <- mx.nd.one.hot(indices = digit, depth = 10)
162+
data <- mx.nd.reshape(data = data, shape = c(1,1,-1, batch_size))
154163

155-
exec_G<- mx.simple.bind(symbol = G_sym, data=data_shape_G, ctx = devices, grad.req = "null")
164+
exec_G <- mx.simple.bind(symbol = G_sym, data=data_shape_G, ctx = devices, grad.req = "null")
156165
mx.exec.update.arg.arrays(exec_G, G_arg_params, match.name=TRUE)
157166
mx.exec.update.arg.arrays(exec_G, list(data=data), match.name=TRUE)
158167
mx.exec.update.aux.arrays(exec_G, G_aux_params, match.name=TRUE)

0 commit comments

Comments
 (0)