Commit 18cc9ce

[sentiment] Support training multi-hot MLP and writing out embedding files (#225)

Also in this change:

- Adjust some hyperparameters

Example screenshot from using the new `--embeddingFilesPrefix` flag + the Embedding Projector:

![image](https://user-images.githubusercontent.com/16824702/52145038-f0fce480-262d-11e9-9313-9a5014ace25f.png)

1 parent 16cf488 commit 18cc9ce
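
The commit title refers to a "multi-hot MLP": a plain dense network over a multi-hot (vocabulary-indicator) input vector. As a rough sketch of what such a model looks like in TensorFlow.js (the function name and layer sizes below are illustrative assumptions, not the code this commit adds to `train.js`):

```js
import * as tf from '@tensorflow/tfjs';

// Illustrative sketch only: a multi-hot MLP takes a float32 vector of length
// `numWords` (1.0 wherever a vocabulary word occurs in the review, else 0.0)
// and feeds it through plain dense layers. Layer sizes are assumptions here.
function buildMultihotModel(numWords) {
  const model = tf.sequential();
  model.add(tf.layers.dense(
      {units: 16, activation: 'relu', inputShape: [numWords]}));
  model.add(tf.layers.dense({units: 1, activation: 'sigmoid'}));
  model.compile(
      {loss: 'binaryCrossentropy', optimizer: 'adam', metrics: ['acc']});
  return model;
}
```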

File tree

4 files changed: +302 -61 lines changed

sentiment/README.md

+33 -1

```diff
@@ -37,8 +37,11 @@ yarn train <MODEL_TYPE>
 where `MODEL_TYPE` is a required argument that specifies what type of model is to be
 trained. The available options are:
 
+- `multihot`: A model that takes a multi-hot encoding of the words in the sequence.
+  In terms of data representation and model complexity, this is the simplest model
+  in this example.
 - `flatten`: A model that flattens the embedding vectors of all words in the sequence.
-- `cnn`: A 1D convolutional model.
+- `cnn`: A 1D convolutional model, with a dropout layer included.
 - `simpleRNN`: A model that uses a SimpleRNN layer (`tf.layers.simpleRNN`)
 - `lstm`: A model that uses an LSTM layer (`tf.layers.lstm`)
 - `bidirectionalLSTM`: A model that uses a bidirectional LSTM layer
@@ -65,5 +68,34 @@ Other arguments of the `yarn train` command include:
 - `--epochs`, `--batchSize`, and `--validationSplit` are training-related settings.
 - `--modelSavePath` allows you to specify where to store the model and metadata after
   training completes.
+- `--embeddingFilesPrefix`: Prefix for the path to which to save the embedding vectors
+  and labels files (optional). See the section below for details.
 
 The detailed code for training is in the file [train.js](./train.js).
+
+### Visualizing the word embeddings in the Embedding Projector
+
+If you train a word embedding-based model (e.g., `cnn` or `lstm`), you can let the
+`yarn train` script write the embedding vectors, together with the corresponding
+word labels, to files after the model training completes. This is done using the
+`--embeddingFilesPrefix` flag, e.g.,
+
+```sh
+yarn train --maxLen 500 cnn --epochs 2 --embeddingFilesPrefix /tmp/imdb_embed
+```
+
+The above command will generate two files:
+
+- `/tmp/imdb_embed_vectors.tsv`: A tab-separated-values file containing the numeric
+  values of the word embeddings. Each line contains the embedding vector for a
+  word.
+- `/tmp/imdb_embed_labels.tsv`: A file consisting of the word labels that correspond
+  to the vectors in the previous file. Each line is a word.
+
+These files can be directly uploaded to the Embedding Projector
+(https://projector.tensorflow.org/) for visualization using the
+[T-SNE](https://en.wikipedia.org/wiki/T-distributed_stochastic_neighbor_embedding) or
+[PCA](https://en.wikipedia.org/wiki/Principal_component_analysis) algorithms.
+
+See example screenshot:
+![image](https://user-images.githubusercontent.com/16824702/52145038-f0fce480-262d-11e9-9313-9a5014ace25f.png)
```
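
As a quick sanity check on the two generated files described above, the vectors and labels can be reloaded and paired row by row in Node. This is a hedged sketch assuming the files produced by the command shown earlier; it is not part of the example's code:

```js
import {readFileSync} from 'fs';

// Sketch (not from this commit): pair each word label with its embedding
// vector. Row i of the vectors file corresponds to line i of the labels file.
const vectors = readFileSync('/tmp/imdb_embed_vectors.tsv', 'utf-8')
    .trim().split('\n')
    .map(line => line.split('\t').map(Number));
const labels = readFileSync('/tmp/imdb_embed_labels.tsv', 'utf-8')
    .trim().split('\n');
console.log(labels.length === vectors.length);  // Should print `true`.
console.log(labels[100], vectors[100]);  // A word and its embedding vector.
```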

sentiment/data.js

+57 -11

```diff
@@ -39,10 +39,14 @@ const METADATA_TEMPLATE_URL =
  *   that exceed this limit will be marked as `OOV_INDEX`.
  * @param {string} maxLen Length of each sequence. Longer sequences will be
  *   pre-truncated; shorter ones will be pre-padded.
- * @return {tf.Tensor} The dataset represented as a 2D `tf.Tensor` of shape
- *   `[]` and dtype `int32` .
+ * @param {boolean} multihot Whether to use multi-hot encoding of the words.
+ *   Default: `false`.
+ * @return {tf.Tensor} If `multihot` is `false` (default), the dataset
+ *   represented as a 2D `tf.Tensor` of shape `[numExamples, maxLen]` and
+ *   dtype `int32`. Else, the dataset represented as a 2D `tf.Tensor` of
+ *   shape `[numExamples, numWords]` and dtype `float32`.
  */
-function loadFeatures(filePath, numWords, maxLen) {
+function loadFeatures(filePath, numWords, maxLen, multihot = false) {
   const buffer = fs.readFileSync(filePath);
   const numBytes = buffer.byteLength;
 
@@ -67,10 +71,39 @@ function loadFeatures(filePath, numWords, maxLen) {
   if (seq.length > 0) {
     sequences.push(seq);
   }
-  const paddedSequences =
-      padSequences(sequences, maxLen, 'pre', 'pre');
-  return tf.tensor2d(
-      paddedSequences, [paddedSequences.length, maxLen], 'int32');
+
+  // Get some sequence length stats.
+  let minLength = Infinity;
+  let maxLength = -Infinity;
+  sequences.forEach(seq => {
+    const length = seq.length;
+    if (length < minLength) {
+      minLength = length;
+    }
+    if (length > maxLength) {
+      maxLength = length;
+    }
+  });
+  console.log(`Sequence length: min = ${minLength}; max = ${maxLength}`);
+
+  if (multihot) {
+    // If requested by the arg, encode the sequences as multi-hot
+    // vectors.
+    const buffer = tf.buffer([sequences.length, numWords]);
+    sequences.forEach((seq, i) => {
+      seq.forEach(wordIndex => {
+        if (wordIndex !== OOV_CHAR) {
+          buffer.set(1, i, wordIndex);
+        }
+      });
+    });
+    return buffer.toTensor();
+  } else {
+    const paddedSequences =
+        padSequences(sequences, maxLen, 'pre', 'pre');
+    return tf.tensor2d(
+        paddedSequences, [paddedSequences.length, maxLen], 'int32');
+  }
 }
 
 /**
@@ -84,10 +117,23 @@ function loadTargets(filePath) {
   const buffer = fs.readFileSync(filePath);
   const numBytes = buffer.byteLength;
 
+  let numPositive = 0;
+  let numNegative = 0;
+
   let ys = [];
   for (let i = 0; i < numBytes; ++i) {
-    ys.push(buffer.readUInt8(i));
+    const y = buffer.readUInt8(i);
+    if (y === 1) {
+      numPositive++;
+    } else {
+      numNegative++;
+    }
+    ys.push(y);
   }
+
+  console.log(
+      `Loaded ${numPositive} positive examples and ` +
+      `${numNegative} negative examples.`);
   return tf.tensor2d(ys, [ys.length, 1], 'float32');
 }
 
@@ -171,13 +217,13 @@ async function maybeDownloadAndExtract() {
  *   xTest: The same as `xTrain`, but for the test dataset.
  *   yTest: The same as `yTrain`, but for the test dataset.
  */
-export async function loadData(numWords, len) {
+export async function loadData(numWords, len, multihot = false) {
   const dataDir = await maybeDownloadAndExtract();
 
   const trainFeaturePath = path.join(dataDir, 'imdb_train_data.bin');
-  const xTrain = loadFeatures(trainFeaturePath, numWords, len);
+  const xTrain = loadFeatures(trainFeaturePath, numWords, len, multihot);
   const testFeaturePath = path.join(dataDir, 'imdb_test_data.bin');
-  const xTest = loadFeatures(testFeaturePath, numWords, len);
+  const xTest = loadFeatures(testFeaturePath, numWords, len, multihot);
   const trainTargetsPath = path.join(dataDir, 'imdb_train_targets.bin');
   const yTrain = loadTargets(trainTargetsPath);
   const testTargetsPath = path.join(dataDir, 'imdb_test_targets.bin');
```
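
To make the multi-hot branch added to `loadFeatures` concrete, here is a toy illustration of the same encoding with an assumed six-word vocabulary (the `OOV_CHAR` filtering is omitted for brevity):

```js
import * as tf from '@tensorflow/tfjs';

// Toy illustration of the multi-hot encoding used in loadFeatures above
// (assumed vocabulary size of 6; the real vocabulary has thousands of words).
const numWords = 6;
const sequences = [[2, 4, 2], [1, 5]];

const buffer = tf.buffer([sequences.length, numWords]);
sequences.forEach((seq, i) => {
  seq.forEach(wordIndex => buffer.set(1, i, wordIndex));
});
buffer.toTensor().print();
// [[0, 0, 1, 0, 1, 0],   <- index 2 occurs twice, but the entry is still 1
//  [0, 1, 0, 0, 0, 1]]
```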

sentiment/embedding.js

+118 (new file)

```js
/**
 * @license
 * Copyright 2019 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

/**
 * Utilities for extracting the embedding matrix of a model and writing it
 * out as files.
 */

import {writeFileSync} from 'fs';
import * as tf from '@tensorflow/tfjs';

/**
 * Extract the first embedding matrix from a TensorFlow.js model.
 *
 * @param {tf.Model} model An instance of tf.Model, assumed to contain an
 *   Embedding layer.
 * @returns {tf.Tensor} The embedding matrix from the first Embedding
 *   layer encountered while iterating through all layers of the model.
 * @throws Error if no embedding layer can be found in the model.
 */
function extractEmbeddingMatrix(model) {
  for (const layer of model.layers) {
    if (layer.getClassName() === 'Embedding') {
      const embed = layer.getWeights()[0];
      tf.util.assert(
          embed.rank === 2,
          `Expected the rank of an embedding matrix to be 2, ` +
          `but got ${embed.rank}`);
      return embed;
    }
  }
  throw new Error('Cannot find any Embedding layer in model.');
}

/**
 * Write the values of the first embedding matrix of a model to files.
 *
 * The word labels are written as well. The vectors and labels files are
 * directly loadable into the Embedding Projector
 * (https://projector.tensorflow.org/).
 *
 * @param {tf.Model} model An instance of tf.Model, assumed to contain an
 *   Embedding layer.
 * @param {string} prefix Path prefix for writing the vectors and labels files.
 *   For example, if `prefix` is `/tmp/embed`, then
 *   - the vectors will be written to `/tmp/embed_vectors.tsv`
 *   - the labels will be written to `/tmp/embed_labels.tsv`
 * @param {{[word: string]: number}} wordIndex A dictionary mapping words to
 *   their integer indices.
 * @param {number} indexFrom The base value of the integer indices.
 */
export async function writeEmbeddingMatrixAndLabels(
    model, prefix, wordIndex, indexFrom) {
  tf.util.assert(
      prefix != null && prefix.length > 0,
      `Null, undefined or empty path prefix`);

  const embed = extractEmbeddingMatrix(model);

  const numWords = embed.shape[0];
  const embedDims = embed.shape[1];
  const embedData = await embed.data();

  // Write the embedding matrix to file.
  let vectorsStr = '';
  let index = 0;
  for (let i = 0; i < numWords; ++i) {
    for (let j = 0; j < embedDims; ++j) {
      vectorsStr += embedData[index++].toFixed(5);
      if (j < embedDims - 1) {
        vectorsStr += '\t';
      } else {
        vectorsStr += '\n';
      }
    }
  }

  const vectorsFilePath = `${prefix}_vectors.tsv`;
  writeFileSync(vectorsFilePath, vectorsStr, {encoding: 'utf-8'});
  console.log(
      `Written embedding vectors (${numWords} * ${embedDims}) to: ` +
      `${vectorsFilePath}`);

  // Collect and write the word labels.
  const indexToWord = {};
  for (const word in wordIndex) {
    indexToWord[wordIndex[word]] = word;
  }

  let labelsStr = '';
  for (let i = 0; i < numWords; ++i) {
    if (i >= indexFrom) {
      labelsStr += indexToWord[i - indexFrom];
    } else {
      labelsStr += 'not-a-word';
    }
    labelsStr += '\n';
  }

  const labelsFilePath = `${prefix}_labels.tsv`;
  writeFileSync(labelsFilePath, labelsStr, {encoding: 'utf-8'});
  console.log(
      `Written embedding labels (${numWords}) to: ` +
      `${labelsFilePath}`);
}
```
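
A hedged usage sketch for the new module: the `model`, `wordIndex`, and `indexFrom` values below are stand-ins for values `train.js` would produce; only `writeEmbeddingMatrixAndLabels` itself comes from this commit.

```js
import {writeEmbeddingMatrixAndLabels} from './embedding';

// Sketch only: after training an embedding-based model, dump its embedding
// matrix and word labels for the Embedding Projector. `model`, `wordIndex`,
// and `indexFrom` are stand-ins for values produced elsewhere in train.js.
async function maybeExportEmbeddings(model, wordIndex, indexFrom, prefix) {
  if (prefix != null && prefix.length > 0) {
    await writeEmbeddingMatrixAndLabels(model, prefix, wordIndex, indexFrom);
  }
}
```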
