@@ -30,7 +30,7 @@ require("readr")
30
30
require("mxnet")
31
31
```
32
32
33
- The full demo is comprised of the two following scripts available on [GitHub](https://github.com/jeremiedb/gan_example):
33
+ The full demo is comprised of the two following scripts available on [GitHub](https://github.com/dmlc/mxnet/tree/master/example/gan/blog_R_cgan):
34
34
35
35
- `CGAN_mnist_setup.R`: prepare data and define the model structure
36
36
- `CGAN_train.R`: execute the training
@@ -41,7 +41,7 @@ The MNIST dataset is available on [Kaggle](https://www.kaggle.com/c/digit-recogn
41
41
42
42
``` r
43
43
train <- read_csv('data/train.csv')
44
- train <- data.matrix(train )
44
+ train <- data.matrix(train )
45
45
46
46
train_data <- train [,- 1 ]
47
47
train_data <- t(train_data / 255 * 2 - 1 )
@@ -79,14 +79,18 @@ The training process of the discriminator is most obvious: the loss is simple a
79
79
80
80
``` r
81
81
# ## Train loop on fake
82
- mx.exec.update.arg.arrays(exec_D , arg.arrays = list (data = D_data_fake , digit = D_digit_fake , label = mx.nd.array(rep(0 , batch_size ))), match.name = TRUE )
82
+ mx.exec.update.arg.arrays(exec_D , arg.arrays =
83
+ list (data = D_data_fake , digit = D_digit_fake , label = mx.nd.array(rep(0 , batch_size ))),
84
+ match.name = TRUE )
83
85
mx.exec.forward(exec_D , is.train = T )
84
86
mx.exec.backward(exec_D )
85
87
update_args_D <- updater_D(weight = exec_D $ ref.arg.arrays , grad = exec_D $ ref.grad.arrays )
86
88
mx.exec.update.arg.arrays(exec_D , update_args_D , skip.null = TRUE )
87
89
88
90
# ## Train loop on real
89
- mx.exec.update.arg.arrays(exec_D , arg.arrays = list (data = D_data_real , digit = D_digit_real , label = mx.nd.array(rep(1 , batch_size ))), match.name = TRUE )
91
+ mx.exec.update.arg.arrays(exec_D , arg.arrays =
92
+ list (data = D_data_real , digit = D_digit_real , label = mx.nd.array(rep(1 , batch_size ))),
93
+ match.name = TRUE )
90
94
mx.exec.forward(exec_D , is.train = T )
91
95
mx.exec.backward(exec_D )
92
96
update_args_D <- updater_D(weight = exec_D $ ref.arg.arrays , grad = exec_D $ ref.grad.arrays )
@@ -99,15 +103,20 @@ This requires to backpropagate the gradients up to the input data of the discrim
99
103
100
104
``` r
101
105
# ## Update Generator weights - use a separate executor for writing data gradients
102
- exec_D_back <- mxnet ::: mx.symbol.bind(symbol = D_sym , arg.arrays = exec_D $ arg.arrays , aux.arrays = exec_D $ aux.arrays , grad.reqs = rep(" write" , length(exec_D $ arg.arrays )), ctx = devices )
103
-
104
- mx.exec.update.arg.arrays(exec_D_back , arg.arrays = list (data = D_data_fake , digit = D_digit_fake , label = mx.nd.array(rep(1 , batch_size ))), match.name = TRUE )
106
+ exec_D_back <- mxnet ::: mx.symbol.bind(symbol = D_sym ,
107
+ arg.arrays = exec_D $ arg.arrays ,
108
+ aux.arrays = exec_D $ aux.arrays , grad.reqs = rep(" write" , length(exec_D $ arg.arrays )),
109
+ ctx = devices )
110
+
111
+ mx.exec.update.arg.arrays(exec_D_back , arg.arrays =
112
+ list (data = D_data_fake , digit = D_digit_fake , label = mx.nd.array(rep(1 , batch_size ))),
113
+ match.name = TRUE )
105
114
mx.exec.forward(exec_D_back , is.train = T )
106
115
mx.exec.backward(exec_D_back )
107
- D_grads <- exec_D_back $ ref.grad.arrays $ data
116
+ D_grads <- exec_D_back $ ref.grad.arrays $ data
108
117
mx.exec.backward(exec_G , out_grads = D_grads )
109
118
110
- update_args_G <- updater_G(weight = exec_G $ ref.arg.arrays , grad = exec_G $ ref.grad.arrays )
119
+ update_args_G <- updater_G(weight = exec_G $ ref.arg.arrays , grad = exec_G $ ref.grad.arrays )
111
120
mx.exec.update.arg.arrays(exec_G , update_args_G , skip.null = TRUE )
112
121
```
113
122
@@ -148,11 +157,11 @@ Once the model is trained, synthetic images of the desired digit can be produced
148
157
Here we will generate fake "9":
149
158
150
159
``` r
151
- digit <- mx.nd.array(rep(9 , times = batch_size ))
152
- data <- mx.nd.one.hot(indices = digit , depth = 10 )
153
- data <- mx.nd.reshape(data = data , shape = c(1 ,1 ,- 1 , batch_size ))
160
+ digit <- mx.nd.array(rep(9 , times = batch_size ))
161
+ data <- mx.nd.one.hot(indices = digit , depth = 10 )
162
+ data <- mx.nd.reshape(data = data , shape = c(1 ,1 ,- 1 , batch_size ))
154
163
155
- exec_G <- mx.simple.bind(symbol = G_sym , data = data_shape_G , ctx = devices , grad.req = " null" )
164
+ exec_G <- mx.simple.bind(symbol = G_sym , data = data_shape_G , ctx = devices , grad.req = " null" )
156
165
mx.exec.update.arg.arrays(exec_G , G_arg_params , match.name = TRUE )
157
166
mx.exec.update.arg.arrays(exec_G , list (data = data ), match.name = TRUE )
158
167
mx.exec.update.aux.arrays(exec_G , G_aux_params , match.name = TRUE )
0 commit comments