 * yarn
 * yarn train
 * ```
 *
 * If available, a CUDA GPU will give you a higher training speed:
 *
 * ```sh
 * yarn
 * yarn train --gpu
 * ```
 *
 * To start the demo in the browser, do in a separate terminal:
 *
@@ -50,13 +57,6 @@ const fs = require('fs');
50
57
const path = require ( 'path' ) ;
51
58
52
59
const argparse = require ( 'argparse' ) ;
53
- const tf = require ( '@tensorflow/tfjs' ) ;
54
- require ( '@tensorflow/tfjs-node' ) ;
55
-
56
- // Uncomment me to train the model on GPU.
57
- // Requires: CUDA-enabled GPU, installs of CUDA toolkit and CuDNN.
58
- // require('@tensorflow/tfjs-node-gpu');
59
-
60
60
const data = require ( './data' ) ;
61
61
62
62
// Number of distinct digit classes (0-9) in the MNIST dataset.
const NUM_CLASSES = 10;

// Height and width, in pixels, of each (square) MNIST image.
const IMAGE_SIZE = 28;
67
67
68
// The value of the tf object will be set dynamically, depending on whether
// the CPU (tfjs-node) or GPU (tfjs-node-gpu) backend is used. This is why
// `let` is used in lieu of the more conventional `const` here.
let tf = require('@tensorflow/tfjs');
68
73
/**
69
74
* Build the generator part of ACGAN.
70
75
*
@@ -77,7 +82,7 @@ const IMAGE_SIZE = 28;
77
82
* It generates one output: the generated (i.e., fake) image.
78
83
*
79
84
* @param {number } latentSize Size of the latent space.
80
- * @returns {tf.Model } The generator model.
85
+ * @returns {tf.LayersModel } The generator model.
81
86
*/
82
87
function buildGenerator ( latentSize ) {
83
88
tf . util . assert (
@@ -171,7 +176,7 @@ function buildGenerator(latentSize) {
171
176
* which is the discriminator's 10-class classification result
172
177
* for the input image.
173
178
*
174
- * @returns {tf.Model } The discriminator model.
179
+ * @returns {tf.LayersModel } The discriminator model.
175
180
*/
176
181
function buildDiscriminator ( ) {
177
182
const cnn = tf . sequential ( ) ;
@@ -224,9 +229,43 @@ function buildDiscriminator() {
224
229
return tf . model ( { inputs : image , outputs : [ realnessScore , aux ] } ) ;
225
230
}
226
231
232
/**
 * Build a combined ACGAN model.
 *
 * The combined model feeds the generator's output image into the
 * discriminator, with the discriminator's weights frozen
 * (`trainable = false`), so that training the combined model updates only
 * the generator's weights.
 *
 * @param {number} latentSize Size of the latent vector.
 * @param {tf.LayersModel} generator The generator.
 * @param {tf.LayersModel} discriminator The discriminator.
 * @param {tf.Optimizer} optimizer The optimizer to be used for training the
 *   combined model.
 * @returns {tf.LayersModel} The combined ACGAN model, compiled.
 */
function buildCombinedModel(latentSize, generator, discriminator, optimizer) {
  // Latent vector. This is one of the two inputs to the generator.
  const latent = tf.input({shape: [latentSize]});
  // Desired image class. This is the second input to the generator.
  const imageClass = tf.input({shape: [1]});
  // Get the symbolic tensor for fake images generated by the generator.
  let fake = generator.apply([latent, imageClass]);
  let aux;

  // We only want to be able to train generation for the combined model.
  discriminator.trainable = false;
  [fake, aux] = discriminator.apply(fake);
  const combined =
      tf.model({inputs: [latent, imageClass], outputs: [fake, aux]});
  combined.compile({
    optimizer,
    // Binary crossentropy for the realness score; sparse categorical
    // crossentropy for the 10-class digit classification output.
    loss: ['binaryCrossentropy', 'sparseCategoricalCrossentropy']
  });
  combined.summary();
  return combined;
}
265
+
227
266
// Label value used in place of 1.0 ("soft" one) when training the combined
// ACGAN model. This label-softening is an important trick in training GANs.
const SOFT_ONE = 0.95;
230
269
231
270
/**
232
271
* Train the discriminator for one step.
@@ -252,9 +291,9 @@ const softOne = tf.scalar(0.95);
252
291
* @param {number } batchSize Size of the batch to draw from `xTrain` and
253
292
* `yTrain`.
254
293
* @param {number } latentSize Size of the latent space (z-space).
255
- * @param {tf.Model } generator The generator of the ACGAN.
256
- * @param {tf.Model } discriminator The discriminator of the ACGAN.
257
- * @returns The loss values from the one-step training as numbers.
294
+ * @param {tf.LayersModel } generator The generator of the ACGAN.
295
+ * @param {tf.LayersModel } discriminator The discriminator of the ACGAN.
296
+ * @returns { number[] } The loss values from the one-step training as numbers.
258
297
*/
259
298
async function trainDiscriminatorOneStep (
260
299
xTrain , yTrain , batchStart , batchSize , latentSize , generator ,
@@ -275,9 +314,10 @@ async function trainDiscriminatorOneStep(
275
314
generator . predict ( [ zVectors , sampledLabels ] , { batchSize : batchSize } ) ;
276
315
277
316
const x = tf . concat ( [ imageBatch , generatedImages ] , 0 ) ;
317
+
278
318
const y = tf . tidy (
279
319
( ) => tf . concat (
280
- [ tf . ones ( [ batchSize , 1 ] ) . mul ( softOne ) , tf . zeros ( [ batchSize , 1 ] ) ] ) ) ;
320
+ [ tf . ones ( [ batchSize , 1 ] ) . mul ( SOFT_ONE ) , tf . zeros ( [ batchSize , 1 ] ) ] ) ) ;
281
321
282
322
const auxY = tf . concat ( [ labelBatch , sampledLabels ] , 0 ) ;
283
323
return [ x , y , auxY ] ;
@@ -295,9 +335,9 @@ async function trainDiscriminatorOneStep(
295
335
*
296
336
* @param {number } batchSize Size of the fake-image batch to generate.
297
337
* @param {number } latentSize Size of the latent space (z-space).
298
- * @param {tf.Model } combined The instance of tf.Model that combines
338
+ * @param {tf.LayersModel } combined The instance of tf.LayersModel that combines
299
339
* the generator and the discriminator.
300
- * @returns The loss values from the combined model as numbers.
340
+ * @returns { number[] } The loss values from the combined model as numbers.
301
341
*/
302
342
async function trainCombinedModelOneStep ( batchSize , latentSize , combined ) {
303
343
// TODO(cais): Remove tidy() once the current memory leak issue in tfjs-node
@@ -312,7 +352,7 @@ async function trainCombinedModelOneStep(batchSize, latentSize, combined) {
312
352
// We want to train the generator to trick the discriminator.
313
353
// For the generator, we want all the {fake, not-fake} labels to say
314
354
// not-fake.
315
- const trick = tf . tidy ( ( ) => tf . ones ( [ batchSize , 1 ] ) . mul ( softOne ) ) ;
355
+ const trick = tf . tidy ( ( ) => tf . ones ( [ batchSize , 1 ] ) . mul ( SOFT_ONE ) ) ;
316
356
return [ zVectors , sampledLabels , trick ] ;
317
357
} ) ;
318
358
@@ -322,11 +362,15 @@ async function trainCombinedModelOneStep(batchSize, latentSize, combined) {
322
362
return losses ;
323
363
}
324
364
325
- function buildArgumentParser ( ) {
365
+ function parseArguments ( ) {
326
366
const parser = new argparse . ArgumentParser ( {
327
367
description : 'TensorFlowj.js: MNIST ACGAN trainer example.' ,
328
368
addHelp : true
329
369
} ) ;
370
+ parser . addArgument ( '--gpu' , {
371
+ action : 'storeTrue' ,
372
+ help : 'Use tfjs-node-gpu for training (required CUDA GPU)'
373
+ } ) ;
330
374
parser . addArgument (
331
375
'--epochs' ,
332
376
{ type : 'int' , defaultValue : 100 , help : 'Number of training epochs.' } ) ;
@@ -353,7 +397,11 @@ function buildArgumentParser() {
353
397
defaultValue : './dist/generator' ,
354
398
help : 'Path to which the generator model will be saved after every epoch.'
355
399
} ) ;
356
- return parser ;
400
+ parser . addArgument ( '--logDir' , {
401
+ type : 'string' ,
402
+ help : 'Optional log directory to which the loss values will be written.'
403
+ } ) ;
404
+ return parser . parseArgs ( ) ;
357
405
}
358
406
359
407
function makeMetadata ( totalEpochs , currentEpoch , completed ) {
@@ -366,8 +414,16 @@ function makeMetadata(totalEpochs, currentEpoch, completed) {
366
414
}
367
415
368
416
async function run ( ) {
369
- const parser = buildArgumentParser ( ) ;
370
- const args = parser . parseArgs ( ) ;
417
+ const args = parseArguments ( ) ;
418
+ // Set the value of tf depending on whether the CPU or GPU version of
419
+ // libtensorflow is used.
420
+ if ( args . gpu ) {
421
+ console . log ( 'Using GPU' ) ;
422
+ tf = require ( '@tensorflow/tfjs-node-gpu' ) ;
423
+ } else {
424
+ console . log ( 'Using CPU' ) ;
425
+ tf = require ( '@tensorflow/tfjs-node' ) ;
426
+ }
371
427
372
428
if ( ! fs . existsSync ( path . dirname ( args . generatorSavePath ) ) ) {
373
429
fs . mkdirSync ( path . dirname ( args . generatorSavePath ) ) ;
@@ -387,23 +443,9 @@ async function run() {
387
443
const generator = buildGenerator ( args . latentSize ) ;
388
444
generator . summary ( ) ;
389
445
390
- const latent = tf . input ( { shape : [ args . latentSize ] } ) ;
391
- const imageClass = tf . input ( { shape : [ 1 ] } ) ;
392
-
393
- // Get a fake image.
394
- let fake = generator . apply ( [ latent , imageClass ] ) ;
395
- let aux ;
396
-
397
- // We only want to be able to train generation for the combined model.
398
- discriminator . trainable = false ;
399
- [ fake , aux ] = discriminator . apply ( fake ) ;
400
- const combined =
401
- tf . model ( { inputs : [ latent , imageClass ] , outputs : [ fake , aux ] } ) ;
402
- combined . compile ( {
403
- optimizer : tf . train . adam ( args . learningRate , args . adamBeta1 ) ,
404
- loss : [ 'binaryCrossentropy' , 'sparseCategoricalCrossentropy' ]
405
- } ) ;
406
- combined . summary ( ) ;
446
+ const optimizer = tf . train . adam ( args . learningRate , args . adamBeta1 ) ;
447
+ const combined = buildCombinedModel (
448
+ args . latentSize , generator , discriminator , optimizer ) ;
407
449
408
450
await data . loadData ( ) ;
409
451
let { images : xTrain , labels : yTrain } = data . getTrainData ( ) ;
@@ -413,6 +455,13 @@ async function run() {
413
455
await generator . save ( saveURL ) ;
414
456
415
457
let numTensors ;
458
+ let logWriter ;
459
+ if ( args . logDir ) {
460
+ console . log ( `Logging to tensorboard at logdir: ${ args . logDir } ` ) ;
461
+ logWriter = tf . node . summaryFileWriter ( args . logDir ) ;
462
+ }
463
+
464
+ let step = 0 ;
416
465
for ( let epoch = 0 ; epoch < args . epochs ; ++ epoch ) {
417
466
// Write some metadata to disk at the beginning of every epoch.
418
467
fs . writeFileSync (
@@ -442,7 +491,11 @@ async function run() {
442
491
`epoch ${ epoch + 1 } /${ args . epochs } batch ${ batch + 1 } /${
443
492
numBatches } : ` +
444
493
`dLoss = ${ dLoss [ 0 ] . toFixed ( 6 ) } , gLoss = ${ gLoss [ 0 ] . toFixed ( 6 ) } ` ) ;
445
- tf . dispose ( [ dLoss , gLoss ] ) ;
494
+ if ( logWriter != null ) {
495
+ logWriter . scalar ( 'dLoss' , dLoss [ 0 ] , step ) ;
496
+ logWriter . scalar ( 'gLoss' , gLoss [ 0 ] , step ) ;
497
+ step ++ ;
498
+ }
446
499
447
500
// Assert on no memory leak.
448
501
// TODO(cais): Remove this check once the current memory leak in
@@ -463,16 +516,20 @@ async function run() {
463
516
console . log ( `Saved generator model to: ${ saveURL } \n` ) ;
464
517
}
465
518
466
- // Write metadata to disk to indicate
467
- // the end of the training.
519
+ // Write metadata to disk to indicate the end of the training.
468
520
fs . writeFileSync (
469
521
metadataPath ,
470
522
JSON . stringify ( makeMetadata ( args . epochs , args . epochs , true ) ) ) ;
471
523
}
472
524
473
- run ( ) ;
525
// Kick off training only when this file is executed directly, not when it
// is require()d as a module.
if (require.main === module) {
  run();
}

// Model-building and single-step training functions, exported for reuse
// (e.g., by tests or other scripts).
module.exports = {
  buildCombinedModel,
  buildDiscriminator,
  buildGenerator,
  trainCombinedModelOneStep,
  trainDiscriminatorOneStep
};
0 commit comments