2525 * yarn
2626 * yarn train
2727 * ```
28+ *
29+ * If available, a CUDA GPU will give you a higher training speed:
30+ *
31+ * ```sh
32+ * yarn
33+ * yarn train --gpu
34+ * ```
2835 *
2936 * To start the demo in the browser, do in a separate terminal:
3037 *
@@ -50,13 +57,6 @@ const fs = require('fs');
5057const path = require ( 'path' ) ;
5158
5259const argparse = require ( 'argparse' ) ;
53- const tf = require ( '@tensorflow/tfjs' ) ;
54- require ( '@tensorflow/tfjs-node' ) ;
55-
56- // Uncomment me to train the model on GPU.
57- // Requires: CUDA-enabled GPU, installs of CUDA toolkit and CuDNN.
58- // require('@tensorflow/tfjs-node-gpu');
59-
6060const data = require ( './data' ) ;
6161
6262// Number of classes in the MNIST dataset.
@@ -65,6 +65,11 @@ const NUM_CLASSES = 10;
6565// MNIST image size.
6666const IMAGE_SIZE = 28 ;
6767
68+ // The value of the tf object will be set dynamically, depending on whether
69+ // the CPU (tfjs-node) or GPU (tfjs-node-gpu) backend is used. This is why
70+ * `let` is used in lieu of the more conventional `const` here.
71+ let tf = require ( '@tensorflow/tfjs' ) ;
72+
6873/**
6974 * Build the generator part of ACGAN.
7075 *
@@ -77,7 +82,7 @@ const IMAGE_SIZE = 28;
7782 * It generates one output: the generated (i.e., fake) image.
7883 *
7984 * @param {number } latentSize Size of the latent space.
80- * @returns {tf.Model } The generator model.
85+ * @returns {tf.LayersModel } The generator model.
8186 */
8287function buildGenerator ( latentSize ) {
8388 tf . util . assert (
@@ -171,7 +176,7 @@ function buildGenerator(latentSize) {
171176 * which is the discriminator's 10-class classification result
172177 * for the input image.
173178 *
174- * @returns {tf.Model } The discriminator model.
179+ * @returns {tf.LayersModel } The discriminator model.
175180 */
176181function buildDiscriminator ( ) {
177182 const cnn = tf . sequential ( ) ;
@@ -224,9 +229,43 @@ function buildDiscriminator() {
224229 return tf . model ( { inputs : image , outputs : [ realnessScore , aux ] } ) ;
225230}
226231
/**
 * Build a combined ACGAN model.
 *
 * The combined model chains the generator into the discriminator, with the
 * discriminator's weights frozen, so that training the combined model
 * updates only the generator's weights.
 *
 * Note: the two symbolic inputs (the latent vector and the desired image
 * class) are created inside this function; they are not parameters.
 *
 * @param {number} latentSize Size of the latent vector.
 * @param {tf.LayersModel} generator The generator.
 * @param {tf.LayersModel} discriminator The discriminator.
 * @param {tf.Optimizer} optimizer The optimizer to be used for training the
 *   combined model.
 * @returns {tf.LayersModel} The combined ACGAN model, compiled.
 */
function buildCombinedModel(latentSize, generator, discriminator, optimizer) {
  // Latent vector. This is one of the two inputs to the generator.
  const latent = tf.input({shape: [latentSize]});
  // Desired image class. This is the second input to the generator.
  const imageClass = tf.input({shape: [1]});
  // Get the symbolic tensor for fake images generated by the generator.
  let fake = generator.apply([latent, imageClass]);
  let aux;

  // We only want to be able to train generation for the combined model.
  // Setting `trainable = false` here freezes the discriminator's weights
  // when the combined model is trained.
  discriminator.trainable = false;
  [fake, aux] = discriminator.apply(fake);
  const combined =
      tf.model({inputs: [latent, imageClass], outputs: [fake, aux]});
  combined.compile({
    optimizer,
    // Binary crossentropy for the realness score; sparse categorical
    // crossentropy for the 10-class digit label (integer-encoded).
    loss: ['binaryCrossentropy', 'sparseCategoricalCrossentropy']
  });
  combined.summary();
  return combined;
}
265+
227266// "Soft" one used for training the combined ACGAN model.
228267// This is an important trick in training GANs.
229- const softOne = tf . scalar ( 0.95 ) ;
268+ const SOFT_ONE = 0.95 ;
230269
231270/**
232271 * Train the discriminator for one step.
@@ -252,9 +291,9 @@ const softOne = tf.scalar(0.95);
252291 * @param {number } batchSize Size of the batch to draw from `xTrain` and
253292 * `yTrain`.
254293 * @param {number } latentSize Size of the latent space (z-space).
255- * @param {tf.Model } generator The generator of the ACGAN.
256- * @param {tf.Model } discriminator The discriminator of the ACGAN.
257- * @returns The loss values from the one-step training as numbers.
294+ * @param {tf.LayersModel } generator The generator of the ACGAN.
295+ * @param {tf.LayersModel } discriminator The discriminator of the ACGAN.
296+ * @returns { number[] } The loss values from the one-step training as numbers.
258297 */
259298async function trainDiscriminatorOneStep (
260299 xTrain , yTrain , batchStart , batchSize , latentSize , generator ,
@@ -275,9 +314,10 @@ async function trainDiscriminatorOneStep(
275314 generator . predict ( [ zVectors , sampledLabels ] , { batchSize : batchSize } ) ;
276315
277316 const x = tf . concat ( [ imageBatch , generatedImages ] , 0 ) ;
317+
278318 const y = tf . tidy (
279319 ( ) => tf . concat (
280- [ tf . ones ( [ batchSize , 1 ] ) . mul ( softOne ) , tf . zeros ( [ batchSize , 1 ] ) ] ) ) ;
320+ [ tf . ones ( [ batchSize , 1 ] ) . mul ( SOFT_ONE ) , tf . zeros ( [ batchSize , 1 ] ) ] ) ) ;
281321
282322 const auxY = tf . concat ( [ labelBatch , sampledLabels ] , 0 ) ;
283323 return [ x , y , auxY ] ;
@@ -295,9 +335,9 @@ async function trainDiscriminatorOneStep(
295335 *
296336 * @param {number } batchSize Size of the fake-image batch to generate.
297337 * @param {number } latentSize Size of the latent space (z-space).
298- * @param {tf.Model } combined The instance of tf.Model that combines
338+ * @param {tf.LayersModel } combined The instance of tf.LayersModel that combines
299339 * the generator and the discriminator.
300- * @returns The loss values from the combined model as numbers.
340+ * @returns { number[] } The loss values from the combined model as numbers.
301341 */
302342async function trainCombinedModelOneStep ( batchSize , latentSize , combined ) {
303343 // TODO(cais): Remove tidy() once the current memory leak issue in tfjs-node
@@ -312,7 +352,7 @@ async function trainCombinedModelOneStep(batchSize, latentSize, combined) {
312352 // We want to train the generator to trick the discriminator.
313353 // For the generator, we want all the {fake, not-fake} labels to say
314354 // not-fake.
315- const trick = tf . tidy ( ( ) => tf . ones ( [ batchSize , 1 ] ) . mul ( softOne ) ) ;
355+ const trick = tf . tidy ( ( ) => tf . ones ( [ batchSize , 1 ] ) . mul ( SOFT_ONE ) ) ;
316356 return [ zVectors , sampledLabels , trick ] ;
317357 } ) ;
318358
@@ -322,11 +362,15 @@ async function trainCombinedModelOneStep(batchSize, latentSize, combined) {
322362 return losses ;
323363}
324364
325- function buildArgumentParser ( ) {
365+ function parseArguments ( ) {
326366 const parser = new argparse . ArgumentParser ( {
327367 description : 'TensorFlowj.js: MNIST ACGAN trainer example.' ,
328368 addHelp : true
329369 } ) ;
370+ parser . addArgument ( '--gpu' , {
371+ action : 'storeTrue' ,
372+ help : 'Use tfjs-node-gpu for training (required CUDA GPU)'
373+ } ) ;
330374 parser . addArgument (
331375 '--epochs' ,
332376 { type : 'int' , defaultValue : 100 , help : 'Number of training epochs.' } ) ;
@@ -353,7 +397,11 @@ function buildArgumentParser() {
353397 defaultValue : './dist/generator' ,
354398 help : 'Path to which the generator model will be saved after every epoch.'
355399 } ) ;
356- return parser ;
400+ parser . addArgument ( '--logDir' , {
401+ type : 'string' ,
402+ help : 'Optional log directory to which the loss values will be written.'
403+ } ) ;
404+ return parser . parseArgs ( ) ;
357405}
358406
359407function makeMetadata ( totalEpochs , currentEpoch , completed ) {
@@ -366,8 +414,16 @@ function makeMetadata(totalEpochs, currentEpoch, completed) {
366414}
367415
368416async function run ( ) {
369- const parser = buildArgumentParser ( ) ;
370- const args = parser . parseArgs ( ) ;
417+ const args = parseArguments ( ) ;
418+ // Set the value of tf depending on whether the CPU or GPU version of
419+ // libtensorflow is used.
420+ if ( args . gpu ) {
421+ console . log ( 'Using GPU' ) ;
422+ tf = require ( '@tensorflow/tfjs-node-gpu' ) ;
423+ } else {
424+ console . log ( 'Using CPU' ) ;
425+ tf = require ( '@tensorflow/tfjs-node' ) ;
426+ }
371427
372428 if ( ! fs . existsSync ( path . dirname ( args . generatorSavePath ) ) ) {
373429 fs . mkdirSync ( path . dirname ( args . generatorSavePath ) ) ;
@@ -387,23 +443,9 @@ async function run() {
387443 const generator = buildGenerator ( args . latentSize ) ;
388444 generator . summary ( ) ;
389445
390- const latent = tf . input ( { shape : [ args . latentSize ] } ) ;
391- const imageClass = tf . input ( { shape : [ 1 ] } ) ;
392-
393- // Get a fake image.
394- let fake = generator . apply ( [ latent , imageClass ] ) ;
395- let aux ;
396-
397- // We only want to be able to train generation for the combined model.
398- discriminator . trainable = false ;
399- [ fake , aux ] = discriminator . apply ( fake ) ;
400- const combined =
401- tf . model ( { inputs : [ latent , imageClass ] , outputs : [ fake , aux ] } ) ;
402- combined . compile ( {
403- optimizer : tf . train . adam ( args . learningRate , args . adamBeta1 ) ,
404- loss : [ 'binaryCrossentropy' , 'sparseCategoricalCrossentropy' ]
405- } ) ;
406- combined . summary ( ) ;
446+ const optimizer = tf . train . adam ( args . learningRate , args . adamBeta1 ) ;
447+ const combined = buildCombinedModel (
448+ args . latentSize , generator , discriminator , optimizer ) ;
407449
408450 await data . loadData ( ) ;
409451 let { images : xTrain , labels : yTrain } = data . getTrainData ( ) ;
@@ -413,6 +455,13 @@ async function run() {
413455 await generator . save ( saveURL ) ;
414456
415457 let numTensors ;
458+ let logWriter ;
459+ if ( args . logDir ) {
460+ console . log ( `Logging to tensorboard at logdir: ${ args . logDir } ` ) ;
461+ logWriter = tf . node . summaryFileWriter ( args . logDir ) ;
462+ }
463+
464+ let step = 0 ;
416465 for ( let epoch = 0 ; epoch < args . epochs ; ++ epoch ) {
417466 // Write some metadata to disk at the beginning of every epoch.
418467 fs . writeFileSync (
@@ -442,7 +491,11 @@ async function run() {
442491 `epoch ${ epoch + 1 } /${ args . epochs } batch ${ batch + 1 } /${
443492 numBatches } : ` +
444493 `dLoss = ${ dLoss [ 0 ] . toFixed ( 6 ) } , gLoss = ${ gLoss [ 0 ] . toFixed ( 6 ) } ` ) ;
445- tf . dispose ( [ dLoss , gLoss ] ) ;
494+ if ( logWriter != null ) {
495+ logWriter . scalar ( 'dLoss' , dLoss [ 0 ] , step ) ;
496+ logWriter . scalar ( 'gLoss' , gLoss [ 0 ] , step ) ;
497+ step ++ ;
498+ }
446499
447500 // Assert on no memory leak.
448501 // TODO(cais): Remove this check once the current memory leak in
@@ -463,16 +516,20 @@ async function run() {
463516 console . log ( `Saved generator model to: ${ saveURL } \n` ) ;
464517 }
465518
466- // Write metadata to disk to indicate
467- // the end of the training.
519+ // Write metadata to disk to indicate the end of the training.
468520 fs . writeFileSync (
469521 metadataPath ,
470522 JSON . stringify ( makeMetadata ( args . epochs , args . epochs , true ) ) ) ;
471523}
472524
473- run ( ) ;
525+ if ( require . main === module ) {
526+ run ( ) ;
527+ }
474528
475529module . exports = {
530+ buildCombinedModel,
476531 buildDiscriminator,
477- buildGenerator
532+ buildGenerator,
533+ trainCombinedModelOneStep,
534+ trainDiscriminatorOneStep
478535} ;
0 commit comments