Skip to content

Commit 16cf488

Browse files
authored
[mnist-acgan] Improve command line; add tensorboard and unit tests (#245)
- Allow easy switching to CUDA GPU training with the `--gpu` flag like in some of the newer examples. - Add the `--logDir` flag to allow logging the loss values of the generator and the discriminator to tensorboard. - Add unit tests: gan_test.js
1 parent abca94c commit 16cf488

File tree

7 files changed

+345
-56
lines changed

7 files changed

+345
-56
lines changed

mnist-acgan/README.md

+62
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,61 @@ yarn
3838
yarn train
3939
```
4040

41+
If you have a CUDA-enabled GPU on your system, you can add the `--gpu` flag
42+
to train the model on the GPU, which should give you a significant boost in
43+
the speed of training:
44+
45+
```sh
46+
yarn
47+
yarn train --gpu
48+
```
49+
4150
The training job is a long running one and takes a few hours to complete on
4251
a GPU (using @tensorflow/tfjs-node-gpu) and even longer on a CPU
4352
(using @tensorflow/tfjs-node). It saves the generator part of the ACGAN
4453
into the `./dist/generator` folder at the beginning of the training and
4554
at the end of every training epoch. Some additional metadata is
4655
saved with the model as well.
56+
57+
### Monitoring GAN training using TensorBoard
58+
59+
The Node.js-based training script allows you to log the loss values from
60+
the generator and the discriminator to
61+
[TensorBoard](https://www.tensorflow.org/guide/summaries_and_tensorboard).
62+
Relative to printing loss values to the console, which the
63+
training script performs by default, logging to tensorboard has the following
64+
advantages:
65+
66+
1. Persistence of the loss values, so you can have a copy of the training
67+
history available even if the system crashes in the middle of the training
68+
for some reason, while logs in consoles are more ephemeral.
69+
2. Visualizing the loss values as curves makes the trends easier to see (e.g.,
70+
see the screenshot below).
71+
72+
![MNIST ACGAN Training: TensorBoard Example](./mnist-acgan-tensorboard-example.png)
73+
74+
To do this in this example, add the flag `--logDir` to the `yarn train`
75+
command, followed by the directory to which you want the logs to
76+
be written, e.g.,
77+
78+
```sh
79+
yarn train --gpu --logDir /tmp/mnist-acgan-logs
80+
```
81+
82+
Then install tensorboard and start it by pointing it to the log directory:
83+
84+
```sh
85+
# Skip this step if you have already installed tensorboard.
86+
pip install tensorboard
87+
88+
tensorboard --logdir /tmp/mnist-acgan-logs
89+
```
90+
91+
tensorboard will print an HTTP URL in the terminal. Open your browser and
92+
navigate to the URL to view the loss curves in the Scalar dashboard of
93+
TensorBoard.
94+
95+
### Running Generator demo in the Browser
4796

4897
To start the demo in the browser, do in a separate terminal:
4998

@@ -77,3 +126,16 @@ with
77126
```js
78127
require('@tensorflow/tfjs-node-gpu');
79128
```
129+
130+
## Running unit tests
131+
132+
This example comes with JavaScript unit tests. To run them, do:
133+
134+
```sh
135+
pushd ../ # Go to the root directory of tfjs-examples
136+
yarn
137+
popd # Go back to mnist-acgan/
138+
139+
yarn
140+
yarn test
141+
```

mnist-acgan/gan.js

+100-43
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,13 @@
2525
* yarn
2626
* yarn train
2727
* ```
28+
*
29+
* If available, a CUDA GPU will give you a higher training speed:
30+
*
31+
* ```sh
32+
* yarn
33+
* yarn train --gpu
34+
* ```
2835
*
2936
* To start the demo in the browser, do in a separate terminal:
3037
*
@@ -50,13 +57,6 @@ const fs = require('fs');
5057
const path = require('path');
5158

5259
const argparse = require('argparse');
53-
const tf = require('@tensorflow/tfjs');
54-
require('@tensorflow/tfjs-node');
55-
56-
// Uncomment me to train the model on GPU.
57-
// Requires: CUDA-enabled GPU, installs of CUDA toolkit and CuDNN.
58-
// require('@tensorflow/tfjs-node-gpu');
59-
6060
const data = require('./data');
6161

6262
// Number of classes in the MNIST dataset.
@@ -65,6 +65,11 @@ const NUM_CLASSES = 10;
6565
// MNIST image size.
6666
const IMAGE_SIZE = 28;
6767

68+
// The value of the tf object will be set dynamically, depending on whether
69+
// the CPU (tfjs-node) or GPU (tfjs-node-gpu) backend is used. This is why
70+
// `let` is used in lieu of the more conventional `const` here.
71+
let tf = require('@tensorflow/tfjs');
72+
6873
/**
6974
* Build the generator part of ACGAN.
7075
*
@@ -77,7 +82,7 @@ const IMAGE_SIZE = 28;
7782
* It generates one output: the generated (i.e., fake) image.
7883
*
7984
* @param {number} latentSize Size of the latent space.
80-
* @returns {tf.Model} The generator model.
85+
* @returns {tf.LayersModel} The generator model.
8186
*/
8287
function buildGenerator(latentSize) {
8388
tf.util.assert(
@@ -171,7 +176,7 @@ function buildGenerator(latentSize) {
171176
* which is the discriminator's 10-class classification result
172177
* for the input image.
173178
*
174-
* @returns {tf.Model} The discriminator model.
179+
* @returns {tf.LayersModel} The discriminator model.
175180
*/
176181
function buildDiscriminator() {
177182
const cnn = tf.sequential();
@@ -224,9 +229,43 @@ function buildDiscriminator() {
224229
return tf.model({inputs: image, outputs: [realnessScore, aux]});
225230
}
226231

232+
/**
233+
* Build a combined ACGAN model.
234+
*
235+
* @param {number} latentSize Size of the latent vector.
236+
* @param {tf.SymbolicTensor} imageClass Symbolic tensor for the desired image
237+
* class. This is the other input to the generator.
238+
* @param {tf.LayersModel} generator The generator.
239+
* @param {tf.LayersModel} discriminator The discriminator.
240+
* @param {tf.Optimizer} optimizer The optimizer to be used for training the
241+
* combined model.
242+
* @returns {tf.LayersModel} The combined ACGAN model, compiled.
243+
*/
244+
function buildCombinedModel(latentSize, generator, discriminator, optimizer) {
245+
// Latent vector. This is one of the two inputs to the generator.
246+
const latent = tf.input({shape: [latentSize]});
247+
// Desired image class. This is the second input to the generator.
248+
const imageClass = tf.input({shape: [1]});
249+
// Get the symbolic tensor for fake images generated by the generator.
250+
let fake = generator.apply([latent, imageClass]);
251+
let aux;
252+
253+
// We only want to be able to train generation for the combined model.
254+
discriminator.trainable = false;
255+
[fake, aux] = discriminator.apply(fake);
256+
const combined =
257+
tf.model({inputs: [latent, imageClass], outputs: [fake, aux]});
258+
combined.compile({
259+
optimizer,
260+
loss: ['binaryCrossentropy', 'sparseCategoricalCrossentropy']
261+
});
262+
combined.summary();
263+
return combined;
264+
}
265+
227266
// "Soft" one used for training the combined ACGAN model.
228267
// This is an important trick in training GANs.
229-
const softOne = tf.scalar(0.95);
268+
const SOFT_ONE = 0.95;
230269

231270
/**
232271
* Train the discriminator for one step.
@@ -252,9 +291,9 @@ const softOne = tf.scalar(0.95);
252291
* @param {number} batchSize Size of the batch to draw from `xTrain` and
253292
* `yTrain`.
254293
* @param {number} latentSize Size of the latent space (z-space).
255-
* @param {tf.Model} generator The generator of the ACGAN.
256-
* @param {tf.Model} discriminator The discriminator of the ACGAN.
257-
* @returns The loss values from the one-step training as numbers.
294+
* @param {tf.LayersModel} generator The generator of the ACGAN.
295+
* @param {tf.LayersModel} discriminator The discriminator of the ACGAN.
296+
* @returns {number[]} The loss values from the one-step training as numbers.
258297
*/
259298
async function trainDiscriminatorOneStep(
260299
xTrain, yTrain, batchStart, batchSize, latentSize, generator,
@@ -275,9 +314,10 @@ async function trainDiscriminatorOneStep(
275314
generator.predict([zVectors, sampledLabels], {batchSize: batchSize});
276315

277316
const x = tf.concat([imageBatch, generatedImages], 0);
317+
278318
const y = tf.tidy(
279319
() => tf.concat(
280-
[tf.ones([batchSize, 1]).mul(softOne), tf.zeros([batchSize, 1])]));
320+
[tf.ones([batchSize, 1]).mul(SOFT_ONE), tf.zeros([batchSize, 1])]));
281321

282322
const auxY = tf.concat([labelBatch, sampledLabels], 0);
283323
return [x, y, auxY];
@@ -295,9 +335,9 @@ async function trainDiscriminatorOneStep(
295335
*
296336
* @param {number} batchSize Size of the fake-image batch to generate.
297337
* @param {number} latentSize Size of the latent space (z-space).
298-
* @param {tf.Model} combined The instance of tf.Model that combines
338+
* @param {tf.LayersModel} combined The instance of tf.LayersModel that combines
299339
* the generator and the discriminator.
300-
* @returns The loss values from the combined model as numbers.
340+
* @returns {number[]} The loss values from the combined model as numbers.
301341
*/
302342
async function trainCombinedModelOneStep(batchSize, latentSize, combined) {
303343
// TODO(cais): Remove tidy() once the current memory leak issue in tfjs-node
@@ -312,7 +352,7 @@ async function trainCombinedModelOneStep(batchSize, latentSize, combined) {
312352
// We want to train the generator to trick the discriminator.
313353
// For the generator, we want all the {fake, not-fake} labels to say
314354
// not-fake.
315-
const trick = tf.tidy(() => tf.ones([batchSize, 1]).mul(softOne));
355+
const trick = tf.tidy(() => tf.ones([batchSize, 1]).mul(SOFT_ONE));
316356
return [zVectors, sampledLabels, trick];
317357
});
318358

@@ -322,11 +362,15 @@ async function trainCombinedModelOneStep(batchSize, latentSize, combined) {
322362
return losses;
323363
}
324364

325-
function buildArgumentParser() {
365+
function parseArguments() {
326366
const parser = new argparse.ArgumentParser({
327367
description: 'TensorFlowj.js: MNIST ACGAN trainer example.',
328368
addHelp: true
329369
});
370+
parser.addArgument('--gpu', {
371+
action: 'storeTrue',
372+
help: 'Use tfjs-node-gpu for training (required CUDA GPU)'
373+
});
330374
parser.addArgument(
331375
'--epochs',
332376
{type: 'int', defaultValue: 100, help: 'Number of training epochs.'});
@@ -353,7 +397,11 @@ function buildArgumentParser() {
353397
defaultValue: './dist/generator',
354398
help: 'Path to which the generator model will be saved after every epoch.'
355399
});
356-
return parser;
400+
parser.addArgument('--logDir', {
401+
type: 'string',
402+
help: 'Optional log directory to which the loss values will be written.'
403+
});
404+
return parser.parseArgs();
357405
}
358406

359407
function makeMetadata(totalEpochs, currentEpoch, completed) {
@@ -366,8 +414,16 @@ function makeMetadata(totalEpochs, currentEpoch, completed) {
366414
}
367415

368416
async function run() {
369-
const parser = buildArgumentParser();
370-
const args = parser.parseArgs();
417+
const args = parseArguments();
418+
// Set the value of tf depending on whether the CPU or GPU version of
419+
// libtensorflow is used.
420+
if (args.gpu) {
421+
console.log('Using GPU');
422+
tf = require('@tensorflow/tfjs-node-gpu');
423+
} else {
424+
console.log('Using CPU');
425+
tf = require('@tensorflow/tfjs-node');
426+
}
371427

372428
if (!fs.existsSync(path.dirname(args.generatorSavePath))) {
373429
fs.mkdirSync(path.dirname(args.generatorSavePath));
@@ -387,23 +443,9 @@ async function run() {
387443
const generator = buildGenerator(args.latentSize);
388444
generator.summary();
389445

390-
const latent = tf.input({shape: [args.latentSize]});
391-
const imageClass = tf.input({shape: [1]});
392-
393-
// Get a fake image.
394-
let fake = generator.apply([latent, imageClass]);
395-
let aux;
396-
397-
// We only want to be able to train generation for the combined model.
398-
discriminator.trainable = false;
399-
[fake, aux] = discriminator.apply(fake);
400-
const combined =
401-
tf.model({inputs: [latent, imageClass], outputs: [fake, aux]});
402-
combined.compile({
403-
optimizer: tf.train.adam(args.learningRate, args.adamBeta1),
404-
loss: ['binaryCrossentropy', 'sparseCategoricalCrossentropy']
405-
});
406-
combined.summary();
446+
const optimizer = tf.train.adam(args.learningRate, args.adamBeta1);
447+
const combined = buildCombinedModel(
448+
args.latentSize, generator, discriminator, optimizer);
407449

408450
await data.loadData();
409451
let {images: xTrain, labels: yTrain} = data.getTrainData();
@@ -413,6 +455,13 @@ async function run() {
413455
await generator.save(saveURL);
414456

415457
let numTensors;
458+
let logWriter;
459+
if (args.logDir) {
460+
console.log(`Logging to tensorboard at logdir: ${args.logDir}`);
461+
logWriter = tf.node.summaryFileWriter(args.logDir);
462+
}
463+
464+
let step = 0;
416465
for (let epoch = 0; epoch < args.epochs; ++epoch) {
417466
// Write some metadata to disk at the beginning of every epoch.
418467
fs.writeFileSync(
@@ -442,7 +491,11 @@ async function run() {
442491
`epoch ${epoch + 1}/${args.epochs} batch ${batch + 1}/${
443492
numBatches}: ` +
444493
`dLoss = ${dLoss[0].toFixed(6)}, gLoss = ${gLoss[0].toFixed(6)}`);
445-
tf.dispose([dLoss, gLoss]);
494+
if (logWriter != null) {
495+
logWriter.scalar('dLoss', dLoss[0], step);
496+
logWriter.scalar('gLoss', gLoss[0], step);
497+
step++;
498+
}
446499

447500
// Assert on no memory leak.
448501
// TODO(cais): Remove this check once the current memory leak in
@@ -463,16 +516,20 @@ async function run() {
463516
console.log(`Saved generator model to: ${saveURL}\n`);
464517
}
465518

466-
// Write metadata to disk to indicate
467-
// the end of the training.
519+
// Write metadata to disk to indicate the end of the training.
468520
fs.writeFileSync(
469521
metadataPath,
470522
JSON.stringify(makeMetadata(args.epochs, args.epochs, true)));
471523
}
472524

473-
run();
525+
if (require.main === module) {
526+
run();
527+
}
474528

475529
module.exports = {
530+
buildCombinedModel,
476531
buildDiscriminator,
477-
buildGenerator
532+
buildGenerator,
533+
trainCombinedModelOneStep,
534+
trainDiscriminatorOneStep
478535
};

0 commit comments

Comments
 (0)