Skip to content

Commit f979101

Browse files
authored
[lstm-text-generation] Add unit tests and node.js training script (#231)
- Refactoring to support training and inference in both browser and node.js - Add unit tests for model.ts and data.ts - Add command `yarn train` to support training in Node.js - Supports model saving (`--savePath` flag) - Displays randomly sampled text during training, after each epoch of training - Add command `yarn gen` to support text generation in Node.js, based on model files saved from `yarn train` - Update README.md
1 parent ee26ba0 commit f979101

15 files changed

+1383
-197
lines changed

Diff for: lstm-text-generation/.babelrc

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,6 @@
1313
]
1414
],
1515
"plugins": [
16-
"@babel/plugin-transform-runtime"
16+
"transform-runtime"
1717
]
1818
}

Diff for: lstm-text-generation/.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
*.txt

Diff for: lstm-text-generation/README.md

+59
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
# TensorFlow.js Example: Train LSTM to Generate Text
2+
3+
[See this example live!](https://storage.googleapis.com/tfjs-examples/lstm-text-generation/dist/index.html)
24

35
## Overview
46

@@ -35,6 +37,63 @@ https://github.com/keras-team/keras/blob/master/examples/lstm_text_generation.py
3537

3638
## Usage
3739

40+
### Running the Web Demo
41+
42+
The web demo supports model training and text generation. To launch the demo, do:
43+
3844
```sh
3945
yarn && yarn watch
4046
```
47+
48+
### Training Models in Node.js
49+
50+
Training a model in Node.js should give you a faster performance than the browser
51+
environment.
52+
53+
To start a training job, enter command lines such as:
54+
55+
```sh
56+
yarn
57+
yarn train shakespeare \
58+
--lstmLayerSize 128,128 \
59+
--epochs 120 \
60+
--savePath ./my-shakespeare-model
61+
```
62+
63+
- The first argument to `yarn train` (`shakespeare`) specifies what text corpus
64+
to train the model on. See the console output of `yarn train --help` for a set
65+
of supported text data.
66+
- The argument `--lstmLayerSize 128,128` specifies that the next-character
67+
prediction model should contain two LSTM layers stacked on top of each other,
68+
each with 128 units.
69+
- The flag `--epochs` is used to specify the number of training epochs.
70+
- The argument `--savePath ...` lets the training script save the model at the
71+
specified path once the training completes
72+
73+
If you have a CUDA-enabled GPU set up properly on your system, you can
74+
add the `--gpu` flag to the command line to train the model on the GPU, which
75+
should give you a further performance boost.
76+
77+
### Generating Text in Node.js using Saved Model Files
78+
79+
The example command line above generates a set of model files in the
80+
`./my-shakespeare-model` folder after the completion of the training. You can
81+
load the model and use it to generate text. For example:
82+
83+
```sh
84+
yarn gen shakespeare ./my-shakespeare-model/model.json \
85+
--genLength 250 \
86+
--temperature 0.6
87+
```
88+
89+
The command will randomly sample a snippet of text from the shakespeare
90+
text corpus and use it as the seed to generate text.
91+
92+
- The first argument (`shakespeare`) specifies the text corpus.
93+
- The second argument specifies the path to the saved JSON file for the
94+
model, which has been generated in the previous section.
95+
- The `--genLength` flag allows you to speicify how many characters
96+
to generate.
97+
- The `--temperature` flag allows you to specify the stochacity (randomness)
98+
of the generation processs. It should be a number greater than or equal to
99+
zero. The higher the value is, the more random the generated text will be.

Diff for: lstm-text-generation/data.js

+61-1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,29 @@
1717

1818
import * as tf from '@tensorflow/tfjs';
1919

20+
// TODO(cais): Support user-supplied text data.
21+
export const TEXT_DATA_URLS = {
22+
'nietzsche': {
23+
url:
24+
'https://storage.googleapis.com/tfjs-examples/lstm-text-generation/data/nietzsche.txt',
25+
needle: 'Nietzsche'
26+
},
27+
'julesverne': {
28+
url:
29+
'https://storage.googleapis.com/tfjs-examples/lstm-text-generation/data/t1.verne.txt',
30+
needle: 'Jules Verne'
31+
},
32+
'shakespeare': {
33+
url:
34+
'https://storage.googleapis.com/tfjs-examples/lstm-text-generation/data/t8.shakespeare.txt',
35+
needle: 'Shakespeare'
36+
},
37+
'tfjs-code': {
38+
url: 'https://cdn.jsdelivr.net/npm/@tensorflow/[email protected]/dist/tf.js',
39+
needle: 'TensorFlow.js Code (Compiled, 0.11.7)'
40+
}
41+
}
42+
2043
/**
2144
* A class for text data.
2245
*
@@ -38,6 +61,13 @@ export class TextData {
3861
* example of the training data (in `textString`) to the next.
3962
*/
4063
constructor(dataIdentifier, textString, sampleLen, sampleStep) {
64+
tf.util.assert(
65+
sampleLen > 0,
66+
`Expected sampleLen to be a positive integer, but got ${sampleLen}`);
67+
tf.util.assert(
68+
sampleStep > 0,
69+
`Expected sampleStep to be a positive integer, but got ${sampleStep}`);
70+
4171
if (!dataIdentifier) {
4272
throw new Error('Model identifier is not provided.');
4373
}
@@ -51,7 +81,6 @@ export class TextData {
5181

5282
this.getCharSet_();
5383
this.convertAllTextToIndices_();
54-
this.generateExampleBeginIndices_();
5584
}
5685

5786
/**
@@ -98,6 +127,12 @@ export class TextData {
98127
* `ys` has the shape of `[numExamples, this.charSetSize]`.
99128
*/
100129
nextDataEpoch(numExamples) {
130+
this.generateExampleBeginIndices_();
131+
132+
if (numExamples == null) {
133+
numExamples = this.exampleBeginIndices_.length;
134+
}
135+
101136
const xsBuffer = new tf.TensorBuffer([
102137
numExamples, this.sampleLen_, this.charSetSize_]);
103138
const ysBuffer = new tf.TensorBuffer([numExamples, this.charSetSize_]);
@@ -199,3 +234,28 @@ export class TextData {
199234
this.examplePosition_ = 0;
200235
}
201236
}
237+
238+
/**
239+
* Get a file by downloading it if necessary.
240+
*
241+
* @param {string} sourceURL URL to download the file from.
242+
* @param {string} destPath Destination file path on local filesystem.
243+
*/
244+
export async function maybeDownload(sourceURL, destPath) {
245+
const fs = require('fs');
246+
return new Promise(async (resolve, reject) => {
247+
if (!fs.existsSync(destPath) || fs.lstatSync(destPath).size === 0) {
248+
const localZipFile = fs.createWriteStream(destPath);
249+
console.log(`Downloading file from ${sourceURL} to ${destPath}...`);
250+
https.get(sourceURL, response => {
251+
response.pipe(localZipFile);
252+
localZipFile.on('finish', () => {
253+
localZipFile.close(() => resolve());
254+
});
255+
localZipFile.on('error', err => reject(err));
256+
});
257+
} else {
258+
return resolve();
259+
}
260+
});
261+
}

Diff for: lstm-text-generation/data_test.js

+79
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
/**
2+
* @license
3+
* Copyright 2019 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import {TextData} from './data';
19+
20+
// tslint:disable:max-line-length
21+
const FAKE_TEXT = `Lorem ipsum dolor sit amet, consectetur adipiscing elit. Suspendisse tempor aliquet justo non varius. Curabitur eget convallis velit. Vivamus malesuada, tortor ut finibus posuere, libero lacus eleifend felis, sit amet tempus dolor magna id nibh. Praesent non turpis libero. Praesent luctus, neque vitae suscipit suscipit, arcu neque aliquam justo, eget gravida diam augue nec lorem. Etiam scelerisque vel nibh sit amet maximus. Praesent et dui quis elit bibendum elementum a eget velit. Mauris porta lorem ac porttitor congue. Vestibulum lobortis ultrices velit, vitae condimentum elit ultrices a. Vivamus rutrum ultrices eros ac finibus. Orci varius natoque penatibus et magnis dis parturient montes, nascetur ridiculus mus. Morbi a purus a nibh eleifend convallis. Praesent non turpis volutpat, imperdiet lacus in, cursus tellus. Etiam elit velit, ornare sit amet nulla vel, aliquam iaculis mauris.
22+
23+
Phasellus sed sem ut justo sollicitudin cursus at sed neque. Proin tempor finibus nisl, nec aliquam leo porta at. Nullam vel mauris et neque pellentesque laoreet sit amet eu risus. Sed sed ante sed enim hendrerit commodo. Etiam blandit aliquet molestie. Nullam dictum imperdiet enim, quis scelerisque nunc ultricies sit amet. Praesent dictum dictum lobortis. Sed ut ipsum at orci commodo congue.
24+
25+
Aenean pharetra mollis erat, id convallis ante elementum at. Cras semper turpis nec lorem tempus ultrices. Sed eget purus vel est blandit dictum. Praesent auctor, sapien non consequat pellentesque, risus orci sagittis leo, at cursus nibh nisi vel quam. Morbi et orci id quam dictum efficitur ac iaculis nisl. Donec at nunc et nibh accumsan malesuada eu in odio. Donec quis elementum turpis. Vestibulum pretium rhoncus orci, nec gravida nisl hendrerit pellentesque. Cras imperdiet odio a quam mollis, in aliquet neque efficitur. Praesent at tincidunt ipsum. Maecenas neque risus, pretium ut orci sit amet, dignissim auctor dui. Sed finibus nunc elit, rhoncus ornare dui pharetra vitae. Sed ut iaculis ex. Quisque quis molestie ligula. Vivamus egestas rhoncus mollis.
26+
27+
Pellentesque volutpat ipsum vitae ex interdum, eu rhoncus dolor fringilla. Suspendisse potenti. Maecenas in sem leo. Curabitur vestibulum porta vulputate. Nunc quis consectetur enim. Aliquam congue, augue in commodo porttitor, sem tellus posuere augue, ut aliquam sapien massa in est. Duis convallis pellentesque vehicula. Mauris ipsum urna, congue consequat posuere sed, euismod nec mauris. Praesent sollicitudin scelerisque scelerisque. Ut commodo nisl vitae nunc feugiat auctor. Praesent imperdiet magna facilisis nunc vulputate, vel suscipit leo consequat. Duis fermentum rutrum ipsum a laoreet. Nunc dictum libero in quam pellentesque, sit amet tempus tellus suscipit. Curabitur pharetra erat bibendum malesuada rhoncus.
28+
29+
Donec laoreet leo ligula, ut condimentum mi placerat ut. Sed pretium sollicitudin nisl quis tincidunt. Proin id nisl ornare, interdum lorem quis, posuere lacus. Cras cursus mollis scelerisque. Mauris mattis mi sed orci feugiat, et blandit velit tincidunt. Donec ultrices leo vel tellus tincidunt, id vehicula mi commodo. Nulla egestas mollis massa. Etiam blandit nisl eu risus luctus viverra. Mauris eget mi sem.
30+
31+
`;
32+
// tslint:enable:max-line-length
33+
34+
describe('TextData', () => {
35+
it('Creation', () => {
36+
const data = new TextData('LoremIpsum', FAKE_TEXT, 20, 3);
37+
expect(data.sampleLen()).toEqual(20);
38+
expect(data.charSetSize()).toBeGreaterThan(0);
39+
});
40+
41+
it('nextDataEpoch: full pass', () => {
42+
const data = new TextData('LoremIpsum', FAKE_TEXT, 20, 3);
43+
const [xs, ys] = data.nextDataEpoch();
44+
expect(xs.rank).toEqual(3);
45+
expect(ys.rank).toEqual(2);
46+
expect(xs.shape[0]).toEqual(ys.shape[0]);
47+
expect(xs.shape[1]).toEqual(20);
48+
expect(xs.shape[2]).toEqual(ys.shape[1]);
49+
});
50+
51+
it('nextDataEpoch: partial pass', () => {
52+
const data = new TextData('LoremIpsum', FAKE_TEXT, 20, 3);
53+
const [xs, ys] = data.nextDataEpoch(4);
54+
expect(xs.rank).toEqual(3);
55+
expect(ys.rank).toEqual(2);
56+
expect(xs.shape[0]).toEqual(4);
57+
expect(ys.shape[0]).toEqual(4);
58+
expect(xs.shape[1]).toEqual(20);
59+
expect(xs.shape[2]).toEqual(ys.shape[1]);
60+
});
61+
62+
it('getFromCharSet', () => {
63+
const data = new TextData('LoremIpsum', FAKE_TEXT, 20, 3);
64+
const charSetSize = data.charSetSize();
65+
expect(data.getFromCharSet(0)).not.toEqual(data.getFromCharSet(1));
66+
expect(data.getFromCharSet(0))
67+
.not.toEqual(data.getFromCharSet(charSetSize - 1));
68+
expect(data.getFromCharSet(charSetSize)).toBeUndefined();
69+
expect(data.getFromCharSet(-1)).toBeUndefined();
70+
});
71+
72+
it('getRandomSlice', () => {
73+
const data = new TextData('LoremIpsum', FAKE_TEXT, 20, 3);
74+
const [text, indices] = data.getRandomSlice();
75+
expect(typeof text).toEqual('string');
76+
expect(Array.isArray(indices)).toEqual(true);
77+
});
78+
});
79+

Diff for: lstm-text-generation/gen_node.js

+104
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
/**
2+
* @license
3+
* Copyright 2019 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
/**
19+
* Use a trained next-character prediction model to generate some text.
20+
*/
21+
22+
import * as fs from 'fs';
23+
import * as os from 'os';
24+
import * as path from 'path';
25+
import * as argparse from 'argparse';
26+
27+
import * as tf from '@tensorflow/tfjs';
28+
29+
import {maybeDownload, TextData, TEXT_DATA_URLS} from './data';
30+
import {generateText} from './model';
31+
32+
function parseArgs() {
33+
const parser = argparse.ArgumentParser({
34+
description: 'Train an lstm-text-generation model.'
35+
});
36+
parser.addArgument('textDatasetName', {
37+
type: 'string',
38+
choices: Object.keys(TEXT_DATA_URLS),
39+
help: 'Name of the text dataset'
40+
});
41+
parser.addArgument('modelJSONPath', {
42+
type: 'string',
43+
help: 'Path to the trained next-char prediction model saved on disk ' +
44+
'(e.g., ./my-model/model.json)'
45+
});
46+
parser.addArgument('--genLength', {
47+
type: 'int',
48+
defaultValue: 200,
49+
help: 'Length of the text to generate.'
50+
});
51+
parser.addArgument('--temperature', {
52+
type: 'float',
53+
defaultValue: 0.5,
54+
help: 'Temperature value to use for text generation. Higher values ' +
55+
'lead to more random-looking generation results.'
56+
});
57+
parser.addArgument('--gpu', {
58+
action: 'storeTrue',
59+
help: 'Use CUDA GPU for training.'
60+
});
61+
parser.addArgument('--sampleStep', {
62+
type: 'int',
63+
defaultValue: 3,
64+
help: 'Step length: how many characters to skip between one example ' +
65+
'extracted from the text data to the next.'
66+
});
67+
return parser.parseArgs();
68+
}
69+
70+
async function main() {
71+
const args = parseArgs();
72+
73+
if (args.gpu) {
74+
console.log('Using GPU');
75+
require('@tensorflow/tfjs-node-gpu');
76+
} else {
77+
console.log('Using CPU');
78+
require('@tensorflow/tfjs-node');
79+
}
80+
81+
// Load the model.
82+
const model = await tf.loadModel(`file://${args.modelJSONPath}`);
83+
84+
const sampleLen = model.inputs[0].shape[1];
85+
86+
// Create the text data object.
87+
const textDataURL = TEXT_DATA_URLS[args.textDatasetName].url;
88+
const localTextDataPath = path.join(os.tmpdir(), path.basename(textDataURL));
89+
await maybeDownload(textDataURL, localTextDataPath);
90+
const text = fs.readFileSync(localTextDataPath, {encoding: 'utf-8'});
91+
const textData = new TextData('text-data', text, sampleLen, args.sampleStep);
92+
93+
// Get a seed text from the text data object.
94+
const [seed, seedIndices] = textData.getRandomSlice();
95+
96+
console.log(`Seed text:\n"${seed}"\n`);
97+
98+
const generated = await generateText(
99+
model, textData, seedIndices, args.genLength, args.temperature);
100+
101+
console.log(`Generated text:\n"${generated}"\n`);
102+
}
103+
104+
main();

0 commit comments

Comments
 (0)