Skip to content

Commit 74fea2c

Browse files
authored
Upgrade examples to tfjs-node* 1.0.1 (#247)
* Upgrade mnist-node to 1.0.1 * save * save * update date-conversion-attention * jena-weather and lstm-text-gen * mnst-acgan * sentiment * save * save * save * fix
1 parent 18cc9ce commit 74fea2c

File tree

20 files changed

+477
-981
lines changed

20 files changed

+477
-981
lines changed

baseball-node/package.json

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
"webpack-dev-server": "~3.1.14"
1616
},
1717
"dependencies": {
18-
"@tensorflow/tfjs-node": "^0.2.3",
18+
"@tensorflow/tfjs-node": "^1.0.1",
1919
"argparse": "^1.0.10",
2020
"socket.io": "~2.2.0"
2121
}

baseball-node/pitch_type.js

+16-17
Original file line numberDiff line numberDiff line change
@@ -45,19 +45,18 @@ const TEST_DATA_LENGTH = 700;
4545

4646
// Converts a row from the CSV into features and labels.
4747
// Each feature field is normalized within training data constants:
48-
const csvTransform = ([features, labels]) => {
49-
const values = [
50-
normalize(features.vx0, VX0_MIN, VX0_MAX),
51-
normalize(features.vy0, VY0_MIN, VY0_MAX),
52-
normalize(features.vz0, VZ0_MIN, VZ0_MAX),
53-
normalize(features.ax, AX_MIN, AX_MAX),
54-
normalize(features.ay, AY_MIN, AY_MAX),
55-
normalize(features.az, AZ_MIN, AZ_MAX),
56-
normalize(features.start_speed, START_SPEED_MIN, START_SPEED_MAX),
57-
features.left_handed_pitcher
58-
];
59-
return [values, [labels.pitch_code]];
60-
};
48+
const csvTransform =
49+
({xs, ys}) => {
50+
const values = [
51+
normalize(xs.vx0, VX0_MIN, VX0_MAX),
52+
normalize(xs.vy0, VY0_MIN, VY0_MAX),
53+
normalize(xs.vz0, VZ0_MIN, VZ0_MAX), normalize(xs.ax, AX_MIN, AX_MAX),
54+
normalize(xs.ay, AY_MIN, AY_MAX), normalize(xs.az, AZ_MIN, AZ_MAX),
55+
normalize(xs.start_speed, START_SPEED_MIN, START_SPEED_MAX),
56+
xs.left_handed_pitcher
57+
];
58+
return {xs: values, ys: ys.pitch_code};
59+
}
6160

6261
const trainingData =
6362
tf.data.csv(TRAIN_DATA_PATH, {columnConfigs: {pitch_code: {isLabel: true}}})
@@ -93,8 +92,8 @@ model.compile({
9392
async function evaluate(useTestData) {
9493
// TODO(kreeger): Consider using model.evaluateDataset()
9594
let results = {};
96-
await trainingValidationData.forEach((pitchTypeBatch) => {
97-
const values = model.predict(pitchTypeBatch[0]).dataSync();
95+
await trainingValidationData.forEachAsync(pitchTypeBatch => {
96+
const values = model.predict(pitchTypeBatch.xs).dataSync();
9897
const classSize = TRAINING_DATA_LENGTH / NUM_PITCH_CLASSES;
9998
for (let i = 0; i < NUM_PITCH_CLASSES; i++) {
10099
results[pitchFromClassNum(i)] = {
@@ -104,8 +103,8 @@ async function evaluate(useTestData) {
104103
});
105104

106105
if (useTestData) {
107-
await testValidationData.forEach((pitchTypeBatch) => {
108-
const values = model.predict(pitchTypeBatch[0]).dataSync();
106+
await testValidationData.forEachAsync(pitchTypeBatch => {
107+
const values = model.predict(pitchTypeBatch.xs).dataSync();
109108
const classSize = TEST_DATA_LENGTH / NUM_PITCH_CLASSES;
110109
for (let i = 0; i < NUM_PITCH_CLASSES; i++) {
111110
results[pitchFromClassNum(i)].validation =

baseball-node/strike_zone.js

+5-7
Original file line numberDiff line numberDiff line change
@@ -38,15 +38,13 @@ const TEST_DATA_LENGTH = 200;
3838

3939
// Converts a row from the CSV into features and labels.
4040
// Each feature field is normalized within training data constants:
41-
const csvTransform = ([features, labels]) => {
41+
const csvTransform = ({xs, ys}) => {
4242
const values = [
43-
normalize(features.px, PX_MIN, PX_MAX),
44-
normalize(features.pz, PZ_MIN, PZ_MAX),
45-
normalize(features.sz_top, SZ_TOP_MIN, SZ_TOP_MAX),
46-
normalize(features.sz_bot, SZ_BOT_MIN, SZ_BOT_MAX),
47-
features.left_handed_batter
43+
normalize(xs.px, PX_MIN, PX_MAX), normalize(xs.pz, PZ_MIN, PZ_MAX),
44+
normalize(xs.sz_top, SZ_TOP_MIN, SZ_TOP_MAX),
45+
normalize(xs.sz_bot, SZ_BOT_MIN, SZ_BOT_MAX), xs.left_handed_batter
4846
];
49-
return [values, [labels.is_strike]]
47+
return {xs: values, ys: ys.is_strike};
5048
};
5149

5250
const trainingData =

baseball-node/train_pitch_type.js

+2-2
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,9 @@ async function run(epochCount, savePath) {
3131
});
3232

3333
// Eval against test data:
34-
await pitch_type.testValidationData.forEach((data) => {
34+
await pitch_type.testValidationData.forEachAsync(data => {
3535
const evalOutput = pitch_type.model.evaluate(
36-
data[0], data[1], pitch_type.TEST_DATA_LENGTH);
36+
data.xs, data.ys, pitch_type.TEST_DATA_LENGTH);
3737

3838
console.log(
3939
`\nEvaluation result:\n` +

baseball-node/train_strike_zone.js

+2-2
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,9 @@ async function run(epochCount, savePath) {
3131
});
3232

3333
// Eval against test data:
34-
await sz_model.testValidationData.forEach((data) => {
34+
await sz_model.testValidationData.forEachAsync(data => {
3535
const evalOutput =
36-
sz_model.model.evaluate(data[0], data[1], sz_model.TEST_DATA_LENGTH);
36+
sz_model.model.evaluate(data.xs, data.ys, sz_model.TEST_DATA_LENGTH);
3737

3838
console.log(
3939
`\nEvaluation result:\n` +

baseball-node/yarn.lock

+30-106
Original file line numberDiff line numberDiff line change
@@ -2,87 +2,41 @@
22
# yarn lockfile v1
33

44

5-
"@protobufjs/aspromise@^1.1.1", "@protobufjs/aspromise@^1.1.2":
6-
version "1.1.2"
7-
resolved "https://registry.yarnpkg.com/@protobufjs/aspromise/-/aspromise-1.1.2.tgz#9b8b0cc663d669a7d8f6f5d0893a14d348f30fbf"
8-
9-
"@protobufjs/base64@^1.1.2":
10-
version "1.1.2"
11-
resolved "https://registry.yarnpkg.com/@protobufjs/base64/-/base64-1.1.2.tgz#4c85730e59b9a1f1f349047dbf24296034bb2735"
12-
13-
"@protobufjs/codegen@^2.0.4":
14-
version "2.0.4"
15-
resolved "https://registry.yarnpkg.com/@protobufjs/codegen/-/codegen-2.0.4.tgz#7ef37f0d010fb028ad1ad59722e506d9262815cb"
16-
17-
"@protobufjs/eventemitter@^1.1.0":
18-
version "1.1.0"
19-
resolved "https://registry.yarnpkg.com/@protobufjs/eventemitter/-/eventemitter-1.1.0.tgz#355cbc98bafad5978f9ed095f397621f1d066b70"
20-
21-
"@protobufjs/fetch@^1.1.0":
22-
version "1.1.0"
23-
resolved "https://registry.yarnpkg.com/@protobufjs/fetch/-/fetch-1.1.0.tgz#ba99fb598614af65700c1619ff06d454b0d84c45"
24-
dependencies:
25-
"@protobufjs/aspromise" "^1.1.1"
26-
"@protobufjs/inquire" "^1.1.0"
27-
28-
"@protobufjs/float@^1.0.2":
29-
version "1.0.2"
30-
resolved "https://registry.yarnpkg.com/@protobufjs/float/-/float-1.0.2.tgz#5e9e1abdcb73fc0a7cb8b291df78c8cbd97b87d1"
31-
32-
"@protobufjs/inquire@^1.1.0":
33-
version "1.1.0"
34-
resolved "https://registry.yarnpkg.com/@protobufjs/inquire/-/inquire-1.1.0.tgz#ff200e3e7cf2429e2dcafc1140828e8cc638f089"
35-
36-
"@protobufjs/path@^1.1.2":
37-
version "1.1.2"
38-
resolved "https://registry.yarnpkg.com/@protobufjs/path/-/path-1.1.2.tgz#6cc2b20c5c9ad6ad0dccfd21ca7673d8d7fbf68d"
39-
40-
"@protobufjs/pool@^1.1.0":
41-
version "1.1.0"
42-
resolved "https://registry.yarnpkg.com/@protobufjs/pool/-/pool-1.1.0.tgz#09fd15f2d6d3abfa9b65bc366506d6ad7846ff54"
43-
44-
"@protobufjs/utf8@^1.1.0":
45-
version "1.1.0"
46-
resolved "https://registry.yarnpkg.com/@protobufjs/utf8/-/utf8-1.1.0.tgz#a777360b5b39a1a2e5106f8e858f2fd2d060c570"
47-
48-
"@tensorflow/[email protected]":
49-
version "0.7.2"
50-
resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-converter/-/tfjs-converter-0.7.2.tgz#49e578f71eb82d821af05176754c3452b42cfe9c"
51-
integrity sha512-m46mtaF57x2NcxlNUKdJOCUp3ZSJU9bp9MzyEQ0Iz1bW2kKIxx1DDRjuP0fAeHX5H5Mh/tWIHB9yK6NwLz+aQQ==
52-
dependencies:
53-
"@types/long" "~3.0.32"
54-
protobufjs "~6.8.6"
5+
"@tensorflow/[email protected]":
6+
version "1.0.1"
7+
resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-converter/-/tfjs-converter-1.0.1.tgz#250ce5289644c5a7bfc7eb357459dd7a9e27ba7a"
8+
integrity sha512-YpvonHCyTM8imuZU025uc2JLHITUEOvxqku01cV4N018pQnKAvbMuIC4xGRWtkTgE4+GArzR5SLEUFV0MrVjhQ==
559

56-
"@tensorflow/tfjs-core@0.14.5":
57-
version "0.14.5"
58-
resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-core/-/tfjs-core-0.14.5.tgz#17c3beeec31c4cd92b0f79a5ef30c4975a11e408"
59-
integrity sha512-CSUgKuC17J1Ylr1s6iD1k2/tJr9lD16sUEjtzJbtiuTYCELOwujGK/1htunA7o3BwLuU7aqEI92MoKElEKa7qA==
10+
"@tensorflow/tfjs-core@1.0.1":
11+
version "1.0.1"
12+
resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-core/-/tfjs-core-1.0.1.tgz#5348bd1b292b420b95e8591a4131d703cd7d5c3c"
13+
integrity sha512-VIr0SqsezNg/9mLc+fUNYE+0hkZo/F83Pcs9XKjWlE/mpMyjIHH5F2xnn4JAfJO5gWQLtAWHd8P7IzM+1W5r/A==
6014
dependencies:
6115
"@types/seedrandom" "2.4.27"
6216
"@types/webgl-ext" "0.0.30"
6317
"@types/webgl2" "0.0.4"
6418
seedrandom "2.4.3"
6519

66-
"@tensorflow/[email protected].7":
67-
version "0.1.7"
68-
resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-data/-/tfjs-data-0.1.7.tgz#8a4e43313b3d63cdfab719c0c1c47ced2ef321e3"
69-
integrity sha512-RENjeBdBLq7GS9594kQx2GbM0WQV16VfxzzB0j2sq5vJh9GZQi2DB5Emq2LqZWs5rSeh7PDHZylGOn/ve6f8PA==
20+
"@tensorflow/tfjs-data@1.0.1":
21+
version "1.0.1"
22+
resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-data/-/tfjs-data-1.0.1.tgz#63be00bb13c268daf86948d651623f7c4e541c3f"
23+
integrity sha512-XaB2Uaz5Mzgq81NfQxdA13O27LOlwl//kMLno2P8JGb4D/2I8CaNzlL7HpbBpXp2mZvdFDUZsnK/nbKTka+vqw==
7024
dependencies:
7125
"@types/node-fetch" "^2.1.2"
7226
node-fetch "~2.1.2"
7327
seedrandom "~2.4.3"
7428

75-
"@tensorflow/tfjs-layers@0.9.2":
76-
version "0.9.2"
77-
resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-layers/-/tfjs-layers-0.9.2.tgz#f5c1918d1a9660096f259cd1f99f59e689b41b69"
78-
integrity sha512-peB824cEXRBy5IgZPIodd8zpQ/54VGOYbR+zY+Q1Le7v3Np05EoDcL8Z98MtpBHo6jOM7b/3Lf2zjfJVv2qxJA==
29+
"@tensorflow/tfjs-layers@1.0.1":
30+
version "1.0.1"
31+
resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-layers/-/tfjs-layers-1.0.1.tgz#b24ef2fe84e5347fd562da966a1e73b99b1b5afa"
32+
integrity sha512-cI703R/SHRmBstBtA939ri9acSs6lbcDisa2+yc8YMgo38jokO6t06akKPZSZcQFK5gyusDWAYpMDxvI3lcAWA==
7933

80-
"@tensorflow/tfjs-node@^0.2.3":
81-
version "0.2.3"
82-
resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-node/-/tfjs-node-0.2.3.tgz#9408e7162bf27c7f8c59984c30c3141e6d8ec817"
83-
integrity sha512-+VXi6GLsVXXido2DhzK2e1Y/qM9MvQNbbA00TFgGuVbGMmeX0ey97t6W23dT8dnDVPZprC2XSFumcpRoKe8ENg==
34+
"@tensorflow/tfjs-node@^1.0.1":
35+
version "1.0.1"
36+
resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-node/-/tfjs-node-1.0.1.tgz#d15407d2387d4f3d6818a5a6a8df5c8157e29586"
37+
integrity sha512-lGRfG5LgHqXLvVof8Xj3PYVqpyjl/vP282G6ezvE7ikh48orcpFs1P/f+70Bf95cT1LrK0ecmg9ZWE/jftXgRA==
8438
dependencies:
85-
"@tensorflow/tfjs" "~0.14.2"
39+
"@tensorflow/tfjs" "~1.0.1"
8640
adm-zip "^0.4.11"
8741
bindings "~1.3.0"
8842
https-proxy-agent "^2.2.1"
@@ -91,23 +45,15 @@
9145
rimraf "^2.6.2"
9246
tar "^4.4.6"
9347

94-
"@tensorflow/tfjs@~0.14.2":
95-
version "0.14.2"
96-
resolved "https://registry.yarnpkg.com/@tensorflow/tfjs/-/tfjs-0.14.2.tgz#f38fa572286dadfe981c219f5639defd586c20c4"
97-
integrity sha512-d+kBdhn3L/BOIwwc44V1lUrs0O5s49ujhYXVHT9Hs6y3yq+OqPK10am16H1fNcxeMn12/3gGphebglObTD0/Sg==
48+
"@tensorflow/tfjs@~1.0.1":
49+
version "1.0.1"
50+
resolved "https://registry.yarnpkg.com/@tensorflow/tfjs/-/tfjs-1.0.1.tgz#1642d037222e90f8393f3b12ffc38a770d962515"
51+
integrity sha512-EPFnB+ihJc11npoVBm8PWLfgGcMh8KhU2y7T4hpNNDRPTOvZqD/xx5ApVV9j300IHMKcUup25S6V2e5CfTkTbg==
9852
dependencies:
99-
"@tensorflow/tfjs-converter" "0.7.2"
100-
"@tensorflow/tfjs-core" "0.14.5"
101-
"@tensorflow/tfjs-data" "0.1.7"
102-
"@tensorflow/tfjs-layers" "0.9.2"
103-
104-
"@types/long@^4.0.0":
105-
version "4.0.0"
106-
resolved "https://registry.yarnpkg.com/@types/long/-/long-4.0.0.tgz#719551d2352d301ac8b81db732acb6bdc28dbdef"
107-
108-
"@types/long@~3.0.32":
109-
version "3.0.32"
110-
resolved "https://registry.yarnpkg.com/@types/long/-/long-3.0.32.tgz#f4e5af31e9e9b196d8e5fca8a5e2e20aa3d60b69"
53+
"@tensorflow/tfjs-converter" "1.0.1"
54+
"@tensorflow/tfjs-core" "1.0.1"
55+
"@tensorflow/tfjs-data" "1.0.1"
56+
"@tensorflow/tfjs-layers" "1.0.1"
11157

11258
"@types/node-fetch@^2.1.2":
11359
version "2.1.4"
@@ -116,7 +62,7 @@
11662
dependencies:
11763
"@types/node" "*"
11864

119-
"@types/node@*", "@types/node@^10.1.0":
65+
"@types/node@*":
12066
version "10.12.0"
12167
resolved "https://registry.yarnpkg.com/@types/node/-/node-10.12.0.tgz#ea6dcbddbc5b584c83f06c60e82736d8fbb0c235"
12268

@@ -2044,10 +1990,6 @@ loglevel@^1.4.1:
20441990
version "1.6.1"
20451991
resolved "https://registry.yarnpkg.com/loglevel/-/loglevel-1.6.1.tgz#e0fc95133b6ef276cdc8887cdaf24aa6f156f8fa"
20461992

2047-
long@^4.0.0:
2048-
version "4.0.0"
2049-
resolved "https://registry.yarnpkg.com/long/-/long-4.0.0.tgz#9a7b71cfb7d361a194ea555241c92f7468d5bf28"
2050-
20511993
lru-cache@^5.1.1:
20521994
version "5.1.1"
20531995
resolved "https://registry.yarnpkg.com/lru-cache/-/lru-cache-5.1.1.tgz#1da27e6710271947695daf6848e847f01d84b920"
@@ -2653,24 +2595,6 @@ promise-inflight@^1.0.1:
26532595
version "1.0.1"
26542596
resolved "https://registry.yarnpkg.com/promise-inflight/-/promise-inflight-1.0.1.tgz#98472870bf228132fcbdd868129bad12c3c029e3"
26552597

2656-
protobufjs@~6.8.6:
2657-
version "6.8.8"
2658-
resolved "https://registry.yarnpkg.com/protobufjs/-/protobufjs-6.8.8.tgz#c8b4f1282fd7a90e6f5b109ed11c84af82908e7c"
2659-
dependencies:
2660-
"@protobufjs/aspromise" "^1.1.2"
2661-
"@protobufjs/base64" "^1.1.2"
2662-
"@protobufjs/codegen" "^2.0.4"
2663-
"@protobufjs/eventemitter" "^1.1.0"
2664-
"@protobufjs/fetch" "^1.1.0"
2665-
"@protobufjs/float" "^1.0.2"
2666-
"@protobufjs/inquire" "^1.1.0"
2667-
"@protobufjs/path" "^1.1.2"
2668-
"@protobufjs/pool" "^1.1.0"
2669-
"@protobufjs/utf8" "^1.1.0"
2670-
"@types/long" "^4.0.0"
2671-
"@types/node" "^10.1.0"
2672-
long "^4.0.0"
2673-
26742598
proxy-addr@~2.0.4:
26752599
version "2.0.4"
26762600
resolved "https://registry.yarnpkg.com/proxy-addr/-/proxy-addr-2.0.4.tgz#ecfc733bf22ff8c6f407fa275327b9ab67e48b93"

date-conversion-attention/package.json

+3-3
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,12 @@
1515
"watch": "cross-env NODE_ENV=development parcel index.html --no-hmr --open"
1616
},
1717
"dependencies": {
18-
"@tensorflow/tfjs": "^1.0.0",
18+
"@tensorflow/tfjs": "^1.0.1",
1919
"@tensorflow/tfjs-vis": "1.0.3"
2020
},
2121
"devDependencies": {
22-
"@tensorflow/tfjs-node": "^0.3.1",
23-
"@tensorflow/tfjs-node-gpu": "^0.3.1",
22+
"@tensorflow/tfjs-node": "^1.0.1",
23+
"@tensorflow/tfjs-node-gpu": "^1.0.1",
2424
"argparse": "^1.0.10",
2525
"babel-cli": "^6.26.0",
2626
"babel-core": "^6.26.3",

0 commit comments

Comments
 (0)