Commit abca94c
huan authored and caisq committed
Add native JavaScript (TypeScript) to replace translation.py (#243)
In our [translation](https://github.com/tensorflow/tfjs-examples/tree/master/translation) example, we use a Python script to build the model, train it, and save it to a file for future use. I think it is clearer to do this with a native JavaScript script, so I rewrote it from Python in TypeScript. CC @wangtz
1 parent 960c1b2 commit abca94c

8 files changed, +1171 -6 lines changed

Diff for: translation/.gitignore

+1
@@ -0,0 +1 @@
+package-lock.json

Diff for: translation/README.md

+23 -2
@@ -6,9 +6,30 @@ API of TensorFlow.js.
 It demonstrates loading a pretrained model hosted at a URL, using
 `tf.loadLayersModel()`
 
+## Training Demo
+
+The training data was 149,861 English-French sentence pairs available from [http://www.manythings.org/anki](http://www.manythings.org/anki).
+
+### JavaScript/TypeScript Version
+
+To train the demo in JavaScript, do
+
+```sh
+yarn train ${DATA_PATH}
+```
+
+The model was trained in Node.js with Tensorflow.js, which the model code is converted from Python to TypeScript by @[huan](https://github.com/huan) based on the [translation.py](https://github.com/tensorflow/tfjs-examples/blob/master/translation/python/translation.py) example.
+
+### Python Version
+
+```sh
+python python/translation.py ${DATA_PATH}
+```
+
 The model was trained in Python Keras, based on the [lstm_seq2seq](https://github.com/keras-team/keras/blob/master/examples/lstm_seq2seq.py)
-example. The training data was 149,861 English-French sentence pairs available
-from [http://www.manythings.org/anki](http://www.manythings.org/anki).
+example.
+
+## LANCH DEMO
 
 To launch the demo, do
 

Diff for: translation/TRAIN.javascript.md

+471
Large diffs are not rendered by default.

Diff for: translation/TRAIN.python.md

+65
@@ -0,0 +1,65 @@
+$ python translation.py ~/git/javascript-concist-chit-chat/data/fra.txt
+Using TensorFlow backend.
+Number of samples: 10000
+Number of unique input tokens: 69
+Number of unique output tokens: 93
+Max sequence length for inputs: 16
+Max sequence length for outputs: 59
+Saved metadata at: /tmp/translation.keras/metadata.json
+Train on 8000 samples, validate on 2000 samples
+Epoch 1/20
+2019-03-07 23:40:53.879374: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
+8000/8000 [==============================] - 24s 3ms/step - loss: 0.9242 - val_loss: 0.9666
+Epoch 2/20
+8000/8000 [==============================] - 24s 3ms/step - loss: 0.7361 - val_loss: 0.7743
+Epoch 3/20
+8000/8000 [==============================] - 24s 3ms/step - loss: 0.6229 - val_loss: 0.6941
+Epoch 4/20
+8000/8000 [==============================] - 24s 3ms/step - loss: 0.5664 - val_loss: 0.6479
+Epoch 5/20
+8000/8000 [==============================] - 24s 3ms/step - loss: 0.5270 - val_loss: 0.6120
+Epoch 6/20
+8000/8000 [==============================] - 24s 3ms/step - loss: 0.4960 - val_loss: 0.5716
+Epoch 7/20
+8000/8000 [==============================] - 24s 3ms/step - loss: 0.4691 - val_loss: 0.5653
+Epoch 8/20
+8000/8000 [==============================] - 23s 3ms/step - loss: 0.4458 - val_loss: 0.5349
+Epoch 9/20
+8000/8000 [==============================] - 24s 3ms/step - loss: 0.4261 - val_loss: 0.5236
+Epoch 10/20
+8000/8000 [==============================] - 23s 3ms/step - loss: 0.4085 - val_loss: 0.5067
+Epoch 11/20
+8000/8000 [==============================] - 24s 3ms/step - loss: 0.3930 - val_loss: 0.4963
+Epoch 12/20
+8000/8000 [==============================] - 24s 3ms/step - loss: 0.3783 - val_loss: 0.4877
+Epoch 13/20
+8000/8000 [==============================] - 24s 3ms/step - loss: 0.3646 - val_loss: 0.4846
+Epoch 14/20
+8000/8000 [==============================] - 23s 3ms/step - loss: 0.3517 - val_loss: 0.4780
+Epoch 15/20
+8000/8000 [==============================] - 23s 3ms/step - loss: 0.3396 - val_loss: 0.4721
+Epoch 16/20
+8000/8000 [==============================] - 24s 3ms/step - loss: 0.3278 - val_loss: 0.4631
+Epoch 17/20
+8000/8000 [==============================] - 24s 3ms/step - loss: 0.3168 - val_loss: 0.4683
+Epoch 18/20
+8000/8000 [==============================] - 24s 3ms/step - loss: 0.3061 - val_loss: 0.4607
+Epoch 19/20
+8000/8000 [==============================] - 24s 3ms/step - loss: 0.2961 - val_loss: 0.4574
+Epoch 20/20
+8000/8000 [==============================] - 23s 3ms/step - loss: 0.2863 - val_loss: 0.4564
+WARNING:tensorflow:Layer lstm_1 was passed non-serializable keyword arguments: {'initial_state': [<tf.Tensor 'lstm/while/Exit_2:0' shape=(?, 256) dtype=float32>, <tf.Tensor 'lstm/while/Exit_3:0' shape=(?, 256) dtype=float32>]}. They will not be included in the serialized model (and thus will be missing at deserialization time).
+-
+Input sentence: Go.
+Target sentence: Va !
+Decoded sentence: Continuez à nouveau.
+
+-
+Input sentence: Hi.
+Target sentence: Salut !
+Decoded sentence: Restez aven !
+
+-
+Input sentence: Run!
+Target sentence: Cours !
+Decoded sentence: Sais-tou ?
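
The first lines of the training log above report corpus statistics: sample count, unique input and output tokens, and maximum sequence lengths. As a rough illustration of where such numbers come from, here is a minimal sketch of character-level preprocessing. It assumes the conventions of the Keras `lstm_seq2seq` example (tab-separated sentence pairs, with a tab prepended and a newline appended to each target sentence). The `summarizeCorpus` function and `CorpusStats` type are hypothetical names for illustration, not code from this commit.

```typescript
// Hypothetical helper, not taken from the commit: summarizes a tab-separated
// sentence-pair corpus at the character level.
interface CorpusStats {
  numSamples: number;
  numInputTokens: number;
  numOutputTokens: number;
  maxInputLength: number;
  maxOutputLength: number;
}

function summarizeCorpus(text: string, maxSamples: number): CorpusStats {
  const inputChars = new Set<string>();
  const outputChars = new Set<string>();
  let numSamples = 0;
  let maxInputLength = 0;
  let maxOutputLength = 0;

  for (const line of text.split('\n')) {
    if (numSamples >= maxSamples) {
      break;
    }
    const fields = line.split('\t');
    if (fields.length < 2) {
      continue;  // Skip blank or malformed lines.
    }
    const input = fields[0];
    // Assumption: as in the Keras lstm_seq2seq example, a tab marks the start
    // of the target sequence and a newline marks its end.
    const target = '\t' + fields[1] + '\n';

    input.split('').forEach(ch => inputChars.add(ch));
    target.split('').forEach(ch => outputChars.add(ch));
    maxInputLength = Math.max(maxInputLength, input.length);
    maxOutputLength = Math.max(maxOutputLength, target.length);
    numSamples++;
  }

  return {
    numSamples,
    numInputTokens: inputChars.size,
    numOutputTokens: outputChars.size,
    maxInputLength,
    maxOutputLength,
  };
}

// Usage on the three sentence pairs decoded at the end of the log.
const sampleCorpus = 'Go.\tVa !\nHi.\tSalut !\nRun!\tCours !\n';
console.log(summarizeCorpus(sampleCorpus, 10000));
```

Because the start and end markers count as output tokens, the output vocabulary and maximum output length include them, which is one reason the output figures can exceed what the raw target sentences alone would suggest.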

Diff for: translation/build-resources.sh

+1 -4
@@ -65,12 +65,9 @@ mkdir -p "${RESOURCES_ROOT}"
 
 python "${DEMO_DIR}/python/translation.py" \
     "${TRAIN_DATA_PATH}" \
-    --recurrent_initializer glorot_uniform \
+    --recurrent_initializer orthogonal \
     --artifacts_dir "${RESOURCES_ROOT}" \
     --epochs "${TRAIN_EPOCHS}"
-# TODO(cais): This --recurrent_initializer is a workaround for the limitation
-# in TensorFlow.js Layers that the default recurrent initializer "Orthogonal" is
-# currently not supported. Remove this once "Orthogonal" becomes available.
 
 cd ${DEMO_DIR}
 yarn

Diff for: translation/package.json

+11
@@ -15,9 +15,13 @@
   "scripts": {
     "watch": "./serve.sh",
     "build": "cross-env NODE_ENV=production parcel build index.html --no-minify --public-url ./",
+    "train": "node -r ts-node/register translation.ts --artifacts_dir dist/resources",
     "link-local": "yalc link"
   },
   "devDependencies": {
+    "@types/argparse": "^1.0.36",
+    "@types/mkdirp": "^0.5.2",
+    "@tensorflow/tfjs-node": "^0.3.1",
     "babel-core": "^6.26.3",
     "babel-plugin-transform-runtime": "~6.23.0",
     "babel-polyfill": "~6.26.0",
@@ -26,6 +30,13 @@
     "cross-env": "^5.1.6",
     "http-server": "~0.10.0",
     "parcel-bundler": "~1.10.3",
+    "ts-node": "^8.0.3",
+    "typescript": "^3.3.3333",
+    "argparse": "^1.0.10",
+    "invert-kv": "^2.0.0",
+    "mkdirp": "^0.5.1",
+    "readline": "^1.3.0",
+    "zip-array": "^1.0.1",
     "yalc": "~1.0.0-pre.22"
   }
 }
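
The new `train` script fixes `--artifacts_dir dist/resources`, and `yarn train ${DATA_PATH}` appends the data path after it. The commit's `translation.ts` is not rendered on this page, and the added `argparse` dependency suggests it parses its arguments with that package. Purely as an illustration of how the combined command line could be interpreted, here is a dependency-free sketch; `parseTrainArgs`, `TrainArgs`, and the example path `data/fra.txt` are hypothetical.

```typescript
// Hypothetical sketch, not the commit's code: interprets the arguments that
// follow `translation.ts` when the `train` script is run through yarn.
interface TrainArgs {
  dataPath: string;      // Positional argument supplied by the user.
  artifactsDir: string;  // Value of --artifacts_dir from package.json.
}

function parseTrainArgs(argv: string[]): TrainArgs {
  let dataPath: string | undefined;
  let artifactsDir: string | undefined;

  for (let i = 0; i < argv.length; i++) {
    const arg = argv[i];
    if (arg === '--artifacts_dir') {
      artifactsDir = argv[++i];  // The flag's value is the next argument.
    } else if (arg.slice(0, 2) === '--') {
      throw new Error(`Unknown flag: ${arg}`);
    } else {
      dataPath = arg;
    }
  }

  if (dataPath === undefined) {
    throw new Error('Missing training data path');
  }
  if (artifactsDir === undefined) {
    throw new Error('Missing value for --artifacts_dir');
  }
  return {dataPath, artifactsDir};
}

// `yarn train data/fra.txt` would hand the script these trailing arguments.
const trainArgs = parseTrainArgs(['--artifacts_dir', 'dist/resources', 'data/fra.txt']);
console.log(trainArgs);
```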
