@@ -39,10 +39,14 @@ const METADATA_TEMPLATE_URL =
  *   that exceed this limit will be marked as `OOV_INDEX`.
  * @param {number} maxLen Length of each sequence. Longer sequences will be
  *   pre-truncated; shorter ones will be pre-padded.
- * @return {tf.Tensor} The dataset represented as a 2D `tf.Tensor` of shape
- *   `[numExamples, maxLen]` and dtype `int32`.
+ * @param {boolean} multihot Whether to use multi-hot encoding of the words.
+ *   Default: `false`.
+ * @return {tf.Tensor} If `multihot` is `false` (default), the dataset
+ *   represented as a 2D `tf.Tensor` of shape `[numExamples, maxLen]` and
+ *   dtype `int32`. Else, the dataset represented as a 2D `tf.Tensor` of
+ *   shape `[numExamples, numWords]` and dtype `float32`.
  */
-function loadFeatures(filePath, numWords, maxLen) {
+function loadFeatures(filePath, numWords, maxLen, multihot = false) {
   const buffer = fs.readFileSync(filePath);
   const numBytes = buffer.byteLength;
 
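For orientation, a hypothetical call sketch of the updated `loadFeatures` (the file path and the sizes 10000 and 100 are illustrative, not taken from this commit), contrasting the two return shapes described in the doc comment:

```js
// Illustrative only: assumes a local features file and example sizes.
const xPadded = loadFeatures('./imdb_train_data.bin', 10000, 100);
// -> int32 tensor of shape [numExamples, 100]: pre-padded/truncated sequences.

const xMultihot = loadFeatures('./imdb_train_data.bin', 10000, 100, true);
// -> float32 tensor of shape [numExamples, 10000]: multi-hot word vectors.
```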
@@ -67,10 +71,39 @@ function loadFeatures(filePath, numWords, maxLen) {
   if (seq.length > 0) {
     sequences.push(seq);
   }
-  const paddedSequences =
-      padSequences(sequences, maxLen, 'pre', 'pre');
-  return tf.tensor2d(
-      paddedSequences, [paddedSequences.length, maxLen], 'int32');
+
+  // Get some sequence length stats.
+  let minLength = Infinity;
+  let maxLength = -Infinity;
+  sequences.forEach(seq => {
+    const length = seq.length;
+    if (length < minLength) {
+      minLength = length;
+    }
+    if (length > maxLength) {
+      maxLength = length;
+    }
+  });
+  console.log(`Sequence length: min = ${minLength}; max = ${maxLength}`);
+
+  if (multihot) {
+    // If requested by the arg, encode the sequences as multi-hot
+    // vectors.
+    const buffer = tf.buffer([sequences.length, numWords]);
+    sequences.forEach((seq, i) => {
+      seq.forEach(wordIndex => {
+        if (wordIndex !== OOV_CHAR) {
+          buffer.set(1, i, wordIndex);
+        }
+      });
+    });
+    return buffer.toTensor();
+  } else {
+    const paddedSequences =
+        padSequences(sequences, maxLen, 'pre', 'pre');
+    return tf.tensor2d(
+        paddedSequences, [paddedSequences.length, maxLen], 'int32');
+  }
 }
 
 /**
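To make the multi-hot branch above concrete, here is a self-contained sketch of the same `tf.buffer` pattern on toy data (the sequences and vocabulary size are made up, and the OOV check is omitted for brevity; assumes `@tensorflow/tfjs` is installed):

```js
// Toy illustration of multi-hot encoding with tf.buffer (not part of the commit).
const tf = require('@tensorflow/tfjs');

const numWords = 8;                     // toy vocabulary size
const sequences = [[3, 5, 5], [2, 7]];  // toy word-index sequences

const buffer = tf.buffer([sequences.length, numWords]);  // float32 zeros by default
sequences.forEach((seq, i) => {
  seq.forEach(wordIndex => {
    buffer.set(1, i, wordIndex);  // mark presence; repeated words still map to 1
  });
});
buffer.toTensor().print();
// Row 0 has 1s at columns 3 and 5; row 1 has 1s at columns 2 and 7.
```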
@@ -84,10 +117,23 @@ function loadTargets(filePath) {
   const buffer = fs.readFileSync(filePath);
   const numBytes = buffer.byteLength;
 
+  let numPositive = 0;
+  let numNegative = 0;
+
   let ys = [];
   for (let i = 0; i < numBytes; ++i) {
-    ys.push(buffer.readUInt8(i));
+    const y = buffer.readUInt8(i);
+    if (y === 1) {
+      numPositive++;
+    } else {
+      numNegative++;
+    }
+    ys.push(y);
   }
+
+  console.log(
+      `Loaded ${numPositive} positive examples and ` +
+      `${numNegative} negative examples.`);
   return tf.tensor2d(ys, [ys.length, 1], 'float32');
 }
 
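The class balance that `loadTargets` now logs can also be recomputed from the tensor it returns; a hypothetical check (not part of the commit; assumes a valid targets path is in scope):

```js
// Hypothetical verification: the returned tensor is [numExamples, 1], float32, with 0/1 labels.
const yTrain = loadTargets(trainTargetsPath);
const numPositive = yTrain.sum().dataSync()[0];     // number of 1-labels
const numNegative = yTrain.shape[0] - numPositive;  // remaining 0-labels
console.log(`positive: ${numPositive}, negative: ${numNegative}`);
```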
@@ -171,13 +217,13 @@ async function maybeDownloadAndExtract() {
  *   xTest: The same as `xTrain`, but for the test dataset.
  *   yTest: The same as `yTrain`, but for the test dataset.
  */
-export async function loadData(numWords, len) {
+export async function loadData(numWords, len, multihot = false) {
   const dataDir = await maybeDownloadAndExtract();
 
   const trainFeaturePath = path.join(dataDir, 'imdb_train_data.bin');
-  const xTrain = loadFeatures(trainFeaturePath, numWords, len);
+  const xTrain = loadFeatures(trainFeaturePath, numWords, len, multihot);
   const testFeaturePath = path.join(dataDir, 'imdb_test_data.bin');
-  const xTest = loadFeatures(testFeaturePath, numWords, len);
+  const xTest = loadFeatures(testFeaturePath, numWords, len, multihot);
   const trainTargetsPath = path.join(dataDir, 'imdb_train_targets.bin');
   const yTrain = loadTargets(trainTargetsPath);
   const testTargetsPath = path.join(dataDir, 'imdb_test_targets.bin');
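Finally, a hypothetical caller of the updated `loadData` (the module path, vocabulary size, and sequence length are illustrative; it assumes, per the doc comment, that the function resolves to an object with `xTrain`, `yTrain`, `xTest`, and `yTest`):

```js
// Illustrative usage of the new `multihot` flag from another module.
import {loadData} from './data';  // assumed module path

async function main() {
  const {xTrain, yTrain, xTest, yTest} = await loadData(10000, 100, true);
  console.log(xTrain.shape);  // [numTrainExamples, 10000] when multihot is true
  console.log(yTrain.shape);  // [numTrainExamples, 1]
}

main();
```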