Skip to content

Commit 21ee908

Browse files
authored
[date-conversion-attention] Initial check-in of date-conversion-attention (#212)
- This PR checks in only the training scripts. Model inference in the browser will be checked in in a later PR. - Unit tests are written for the data, model training and inference routines, although they are not hooked up with Travis right now. They are just run manually instead.
1 parent 651fd08 commit 21ee908

14 files changed

+4587
-0
lines changed

date-conversion-attention/.babelrc

+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
{
2+
"presets": [
3+
[
4+
"env",
5+
{
6+
"esmodules": false,
7+
"targets": {
8+
"browsers": [
9+
"> 3%"
10+
]
11+
}
12+
}
13+
]
14+
],
15+
"plugins": [
16+
"transform-runtime"
17+
]
18+
}

date-conversion-attention/README.md

+57
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# TensorFlow.js Example: Date Conversion Through an LSTM-Attention Model
2+
3+
## Overview
4+
5+
This example shows how to use TensorFlow.js to train a model based on
6+
long short-term memory (LSTM) and the attention mechanism to achieve
7+
a task of converting various commonly seen date formats (e.g., 01/18/2019,
8+
18JAN2019, 18-01-2019) to the ISO date format (i.e., 2019-01-18).
9+
10+
We demonstrate the full machine-learning workflow, consisting of
11+
data engineering, server-side model training, client-side inference,
12+
model visualization, and unit testing in this example.
13+
14+
The training data is synthesized programmatically.
15+
16+
## Model training in Node.js
17+
18+
For efficiency, the training of the model happens outside the browser
19+
in Node.js, using tfjs-node or tfjs-node-gpu.
20+
21+
To run the training job, do
22+
23+
```sh
24+
yarn
25+
yarn train
26+
```
27+
28+
By default, the training uses tfjs-node, which runs on the CPU.
29+
If you have a CUDA-enabled GPU and have the CUDA and CuDNN libraries
30+
set up properly on your system, you can run the training on the GPU
31+
by:
32+
33+
```sh
34+
yarn
35+
yarn train --gpu
36+
```
37+
38+
## Using the model in the browser
39+
40+
TODO(cais): Implement it.
41+
42+
### Visualization of the attention mechanism
43+
44+
TODO(cais): Implement it.
45+
46+
## Running unit tests
47+
48+
The data and model code in this example are covered by unit tests.
49+
To run the unit tests:
50+
51+
```sh
52+
cd ../
53+
yarn
54+
cd date-conversion-attention
55+
yarn
56+
yarn test
57+
```
+236
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
/**
2+
* @license
3+
* Copyright 2019 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
/**
19+
* Date formats and conversion utility functions.
20+
*
21+
* This file is used for the training of the date-conversion model and
22+
* date conversions based on the trained model.
23+
*
24+
* It contains functions that generate random dates and represent them in
25+
* several different formats such as (2019-01-20 and 20JAN19).
26+
* It also contains functions that convert the text representation of
27+
* the dates into one-hot `tf.Tensor` representations.
28+
*/
29+
30+
const tf = require('@tensorflow/tfjs');
31+
32+
const MONTH_NAMES_FULL = [
33+
'January', 'February', 'March', 'April', 'May', 'June', 'July', 'August',
34+
'September', 'October', 'November', 'December'
35+
];
36+
const MONTH_NAMES_3LETTER =
37+
MONTH_NAMES_FULL.map(name => name.slice(0, 3).toUpperCase());
38+
39+
const MIN_DATE = new Date('1950-01-01').getTime();
40+
const MAX_DATE = new Date('2050-01-01').getTime();
41+
42+
export const INPUT_LENGTH = 12 // Maximum length of all input formats.
43+
export const OUTPUT_LENGTH = 10 // Length of 'YYYY-MM-DD'.
44+
45+
// Use "\n" for padding for both input and output. It has to be at the
46+
// beginning so that `mask_zero=True` can be used in the keras model.
47+
export const INPUT_VOCAB = '\n0123456789/-., ' +
48+
MONTH_NAMES_3LETTER.join('')
49+
.split('')
50+
.filter(function(item, i, ar) {
51+
return ar.indexOf(item) === i;
52+
})
53+
.join('');
54+
55+
// OUTPUT_VOCAB includes an start-of-sequence (SOS) token, represented as
56+
// '\t'. Note that the date strings are represented in terms of their
57+
// constituent characters, not words or anything else.
58+
export const OUTPUT_VOCAB = '\n\t0123456789-';
59+
60+
export const START_CODE = 1;
61+
62+
/**
63+
* Generate a random date.
64+
*
65+
* @return {[number, number, number]} Year as an integer, month as an
66+
* integer >= 1 and <= 12, day as an integer >= 1.
67+
*/
68+
export function generateRandomDateTuple() {
69+
const date = new Date(Math.random() * (MAX_DATE - MIN_DATE) + MIN_DATE);
70+
return [date.getFullYear(), date.getMonth() + 1, date.getDate()];
71+
}
72+
73+
function toTwoDigitString(num) {
74+
return num < 10 ? `0${num}` : `${num}`;
75+
}
76+
77+
/** Date format such as 01202019. */
78+
export function dateTupleToDDMMMYYYY(dateTuple) {
79+
const monthStr = MONTH_NAMES_3LETTER[dateTuple[1] - 1];
80+
const dayStr = toTwoDigitString(dateTuple[2]);
81+
return `${dayStr}${monthStr}${dateTuple[0]}`;
82+
}
83+
84+
/** Date format such as 01/20/2019. */
85+
export function dateTupleToMMSlashDDSlashYYYY(dateTuple) {
86+
const monthStr = toTwoDigitString(dateTuple[1]);
87+
const dayStr = toTwoDigitString(dateTuple[2]);
88+
return `${monthStr}/${dayStr}/${dateTuple[0]}`;
89+
}
90+
91+
/** Date format such as 01/20/19. */
92+
export function dateTupleToMMSlashDDSlashYY(dateTuple) {
93+
const monthStr = toTwoDigitString(dateTuple[1]);
94+
const dayStr = toTwoDigitString(dateTuple[2]);
95+
const yearStr = `${dateTuple[0]}`.slice(2);
96+
return `${monthStr}/${dayStr}/${yearStr}`;
97+
}
98+
99+
/** Date format such as 012019. */
100+
export function dateTupleToMMDDYY(dateTuple) {
101+
const monthStr = toTwoDigitString(dateTuple[1]);
102+
const dayStr = toTwoDigitString(dateTuple[2]);
103+
const yearStr = `${dateTuple[0]}`.slice(2);
104+
return `${monthStr}${dayStr}${yearStr}`;
105+
}
106+
107+
/** Date format such as JAN 20 19. */
108+
export function dateTupleToMMMSpaceDDSpaceYY(dateTuple) {
109+
const monthStr = MONTH_NAMES_3LETTER[dateTuple[1] - 1];
110+
const dayStr = toTwoDigitString(dateTuple[2]);
111+
const yearStr = `${dateTuple[0]}`.slice(2);
112+
return `${monthStr} ${dayStr} ${yearStr}`;
113+
}
114+
115+
/** Date format such as JAN 20 2019. */
116+
export function dateTupleToMMMSpaceDDSpaceYYYY(dateTuple) {
117+
const monthStr = MONTH_NAMES_3LETTER[dateTuple[1] - 1];
118+
const dayStr = toTwoDigitString(dateTuple[2]);
119+
return `${monthStr} ${dayStr} ${dateTuple[0]}`;
120+
}
121+
122+
/** Date format such as JAN 20, 19. */
123+
export function dateTupleToMMMSpaceDDCommaSpaceYY(dateTuple) {
124+
const monthStr = MONTH_NAMES_3LETTER[dateTuple[1] - 1];
125+
const dayStr = toTwoDigitString(dateTuple[2]);
126+
const yearStr = `${dateTuple[0]}`.slice(2);
127+
return `${monthStr} ${dayStr}, ${yearStr}`;
128+
}
129+
130+
/** Date format such as JAN 20, 2019. */
131+
export function dateTupleToMMMSpaceDDCommaSpaceYYYY(dateTuple) {
132+
const monthStr = MONTH_NAMES_3LETTER[dateTuple[1] - 1];
133+
const dayStr = toTwoDigitString(dateTuple[2]);
134+
return `${monthStr} ${dayStr}, ${dateTuple[0]}`;
135+
}
136+
137+
/** Date format such as 20-01-2019. */
138+
export function dateTupleToDDDashMMDashYYYY(dateTuple) {
139+
const monthStr = toTwoDigitString(dateTuple[1]);
140+
const dayStr = toTwoDigitString(dateTuple[2]);
141+
return `${dayStr}-${monthStr}-${dateTuple[0]}`;
142+
}
143+
144+
/** Date format such as 20.01.2019. */
145+
export function dateTupleToDDDotMMDotYYYY(dateTuple) {
146+
const monthStr = toTwoDigitString(dateTuple[1]);
147+
const dayStr = toTwoDigitString(dateTuple[2]);
148+
return `${dayStr}.${monthStr}.${dateTuple[0]}`;
149+
}
150+
151+
/** Date format such as 2019.01.20. */
152+
export function dateTupleToYYYYDotMMDotDD(dateTuple) {
153+
const monthStr = toTwoDigitString(dateTuple[1]);
154+
const dayStr = toTwoDigitString(dateTuple[2]);
155+
return `${dateTuple[0]}.${monthStr}.${dayStr}`;
156+
}
157+
158+
159+
/** Date format such as 20190120 */
160+
export function dateTupleToYYYYMMDD(dateTuple) {
161+
const monthStr = toTwoDigitString(dateTuple[1]);
162+
const dayStr = toTwoDigitString(dateTuple[2]);
163+
return `${dateTuple[0]}${monthStr}${dayStr}`;
164+
}
165+
166+
/**
167+
* Date format such as 2019-01-20
168+
* (i.e., the ISO format and the conversion target).
169+
* */
170+
export function dateTupleToYYYYDashMMDashDD(dateTuple) {
171+
const monthStr = toTwoDigitString(dateTuple[1]);
172+
const dayStr = toTwoDigitString(dateTuple[2]);
173+
return `${dateTuple[0]}-${monthStr}-${dayStr}`;
174+
}
175+
176+
/**
177+
* Encode a number of input date strings as a `tf.Tensor`.
178+
*
179+
* The encoding is a sequence of one-hot vectors. The sequence is
180+
* padded at the end to the maximum possible length of any valid
181+
* input date strings. The padding value is zero.
182+
*
183+
* @param {string[]} dateStrings Input date strings. Each element of the array
184+
* must be one of the formats listed above. It is okay to mix multiple formats
185+
* in the array.
186+
* @returns {tf.Tensor} One-hot encoded characters as a `tf.Tensor`, of dtype
187+
* `float32` and shape `[numExamples, maxInputLength]`, where `maxInputLength`
188+
* is the maximum possible input length of all valid input date-string formats.
189+
*/
190+
export function encodeInputDateStrings(dateStrings) {
191+
const n = dateStrings.length;
192+
const x = tf.buffer([n, INPUT_LENGTH], 'float32');
193+
for (let i = 0; i < n; ++i) {
194+
for (let j = 0; j < INPUT_LENGTH; ++j) {
195+
if (j < dateStrings[i].length) {
196+
const char = dateStrings[i][j];
197+
const index = INPUT_VOCAB.indexOf(char);
198+
if (index === -1) {
199+
throw new Error(`Unknown char: ${char}`);
200+
}
201+
x.set(index, i, j);
202+
}
203+
}
204+
}
205+
return x.toTensor();
206+
}
207+
208+
/**
209+
* Encode a number of output date strings as a `tf.Tensor`.
210+
*
211+
* The encoding is a sequence of integer indices.
212+
*
213+
* @param {string[]} dateStrings An array of output date strings, must be in the
214+
* ISO date format (YYYY-MM-DD).
215+
* @returns {tf.Tensor} Integer indices of the characters as a `tf.Tensor`, of
216+
* dtype `int32` and shape `[numExamples, outputLength]`, where `outputLength`
217+
* is the length of the standard output format (i.e., `10`).
218+
*/
219+
export function encodeOutputDateStrings(dateStrings, oneHot = false) {
220+
const n = dateStrings.length;
221+
const x = tf.buffer([n, OUTPUT_LENGTH], 'int32');
222+
for (let i = 0; i < n; ++i) {
223+
tf.util.assert(
224+
dateStrings[i].length === OUTPUT_LENGTH,
225+
`Date string is not in ISO format: "${dateStrings[i]}"`);
226+
for (let j = 0; j < OUTPUT_LENGTH; ++j) {
227+
const char = dateStrings[i][j];
228+
const index = OUTPUT_VOCAB.indexOf(char);
229+
if (index === -1) {
230+
throw new Error(`Unknown char: ${char}`);
231+
}
232+
x.set(index, i, j);
233+
}
234+
}
235+
return x.toTensor();
236+
}

0 commit comments

Comments
 (0)