Skip to content

Commit b789b5e

Browse files
authored
[cart-pole] Increase network size; Update to tfjs 1.1.0 (#272)
* [cart-pole] Increase network size; Update to tfjs 1.1.0 * Change some variable names for better clarity
1 parent b3b28a3 commit b789b5e

File tree

4 files changed

+72
-41
lines changed

4 files changed

+72
-41
lines changed

cart-pole/index.html

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,8 @@ <h1>TensorFlow.js: Reinforcement Learning</h1>
120120
<div class="with-cols">
121121
<div class="with-rows init-model">
122122
<div class="input-div with-rows">
123-
<label class="input-label">Hidden layer size(s) (e.g.: "5", "8,6"):</label>
124-
<input id="hidden-layer-sizes" value="4"></input>
123+
<label class="input-label">Hidden layer size(s) (e.g.: "256", "32,64"):</label>
124+
<input id="hidden-layer-sizes" value="128"></input>
125125
</div>
126126
<button id="create-model" disabled="true">Create model</button>
127127
</div>
@@ -175,14 +175,15 @@ <h1>TensorFlow.js: Reinforcement Learning</h1>
175175
<section>
176176
<p class='section-head'>Training Progress</p>
177177
<div class="with-rows">
178-
<div class="status">
179-
<label id="iteration-status">Game #:</label>
180-
<progress value="0" max="100" id="iteration-progress"></progress>
181-
</div>
182178
<div class="status">
183179
<label id="train-status">Iteration #:</label>
184180
<progress value="0" max="100" id="train-progress"></progress>
185181
</div>
182+
<div class="status">
183+
<label id="iteration-status">Game #:</label>
184+
<progress value="0" max="100" id="iteration-progress"></progress>
185+
</div>
186+
186187
<div class="status">
187188
<label>Training speed:</label>
188189
<span id="train-speed" class="status-span"></span>

cart-pole/index.js

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,9 @@ class PolicyNetwork {
6060
*/
6161
constructor(hiddenLayerSizesOrModel) {
6262
if (hiddenLayerSizesOrModel instanceof tf.LayersModel) {
63-
this.model = hiddenLayerSizesOrModel;
63+
this.policyNet = hiddenLayerSizesOrModel;
6464
} else {
65-
this.createModel(hiddenLayerSizesOrModel);
65+
this.createPolicyNetwork(hiddenLayerSizesOrModel);
6666
}
6767
}
6868

@@ -73,13 +73,13 @@ class PolicyNetwork {
7373
* a single number (for a single hidden layer) or an Array of numbers (for
7474
* any number of hidden layers).
7575
*/
76-
createModel(hiddenLayerSizes) {
76+
createPolicyNetwork(hiddenLayerSizes) {
7777
if (!Array.isArray(hiddenLayerSizes)) {
7878
hiddenLayerSizes = [hiddenLayerSizes];
7979
}
80-
this.model = tf.sequential();
80+
this.policyNet = tf.sequential();
8181
hiddenLayerSizes.forEach((hiddenLayerSize, i) => {
82-
this.model.add(tf.layers.dense({
82+
this.policyNet.add(tf.layers.dense({
8383
units: hiddenLayerSize,
8484
activation: 'elu',
8585
// `inputShape` is required only for the first layer.
@@ -88,7 +88,7 @@ class PolicyNetwork {
8888
});
8989
// The last layer has only one unit. The single output number will be
9090
// converted to a probability of selecting the leftward-force action.
91-
this.model.add(tf.layers.dense({units: 1}));
91+
this.policyNet.add(tf.layers.dense({units: 1}));
9292
}
9393

9494
/**
@@ -203,7 +203,7 @@ class PolicyNetwork {
203203
*/
204204
getLogitsAndActions(inputs) {
205205
return tf.tidy(() => {
206-
const logits = this.model.predict(inputs);
206+
const logits = this.policyNet.predict(inputs);
207207

208208
// Get the probability of the leftward action.
209209
const leftProb = tf.sigmoid(logits);
@@ -265,7 +265,7 @@ export class SaveablePolicyNetwork extends PolicyNetwork {
265265
* Save the model to IndexedDB.
266266
*/
267267
async saveModel() {
268-
return await this.model.save(MODEL_SAVE_PATH_);
268+
return await this.policyNet.save(MODEL_SAVE_PATH_);
269269
}
270270

271271
/**
@@ -314,8 +314,8 @@ export class SaveablePolicyNetwork extends PolicyNetwork {
314314
*/
315315
hiddenLayerSizes() {
316316
const sizes = [];
317-
for (let i = 0; i < this.model.layers.length - 1; ++i) {
318-
sizes.push(this.model.layers[i].units);
317+
for (let i = 0; i < this.policyNet.layers.length - 1; ++i) {
318+
sizes.push(this.policyNet.layers[i].units);
319319
}
320320
return sizes.length === 1 ? sizes[0] : sizes;
321321
}

cart-pole/package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
"node": ">=8.9.0"
1010
},
1111
"dependencies": {
12-
"@tensorflow/tfjs": "^1.0.4",
12+
"@tensorflow/tfjs": "^1.1.0",
1313
"@tensorflow/tfjs-vis": "^1.0.3"
1414
},
1515
"scripts": {

cart-pole/yarn.lock

Lines changed: 54 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -652,33 +652,36 @@
652652
resolved "https://registry.yarnpkg.com/@nodelib/fs.stat/-/fs.stat-1.1.3.tgz#2b5a3ab3f918cca48a8c754c08168e3f03eba61b"
653653
integrity sha512-shAmDyaQC4H92APFoIaVDHCx5bStIocgvbwQyxPRrbUY20V1EYTbSDchWbuwlMG3V17cprZhA6+78JfB+3DTPw==
654654

655-
"@tensorflow/tfjs-converter@1.0.4":
656-
version "1.0.4"
657-
resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-converter/-/tfjs-converter-1.0.4.tgz#824f546bf47ceaedf7affbadff3bdd151b530066"
658-
integrity sha512-r619cIleJy9btVN12KoFCi0HIgi4qzzgxjxQCDxaVgi7Axas7MwBtAKGTZppR4+eK7eLOVYT1ah95JUDdFNGRw==
655+
"@tensorflow/tfjs-converter@1.1.0":
656+
version "1.1.0"
657+
resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-converter/-/tfjs-converter-1.1.0.tgz#1d6f58347e9b3826c02090e06e6590b0c4df2d4e"
658+
integrity sha512-gUkoRoYm9yrVVQNp8nD+pEWOPUNhayCSrUHNItSfIm8Lzbgx6brVxVdz5T8V0kT0yh67Pp9Er/LIlf54p7KikA==
659659

660-
"@tensorflow/tfjs-core@1.0.4":
661-
version "1.0.4"
662-
resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-core/-/tfjs-core-1.0.4.tgz#98365d04d179b3244ec59a270ec5e31e21de3b35"
663-
integrity sha512-PpdrCozkvRx3hFyHogZFjbZUmE1t/SB/iiE8yj195aCy3X2D8T66+aLS2iUkdU2tr+1bccyvPkBktwX6/lKtAw==
660+
"@tensorflow/tfjs-core@1.1.0":
661+
version "1.1.0"
662+
resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-core/-/tfjs-core-1.1.0.tgz#028c69291e19c328c4c30e18d29b09135c22cb44"
663+
integrity sha512-loPpHGVjiyEb+Ixlsj8prQ/r4exekITn7vM4WEyHUouFKx0/CuoB2FQ0m6DSb/6ApvucxTWGGNTRRo4HK4Ma0Q==
664664
dependencies:
665665
"@types/seedrandom" "2.4.27"
666666
"@types/webgl-ext" "0.0.30"
667667
"@types/webgl2" "0.0.4"
668+
node-fetch "~2.1.2"
668669
seedrandom "2.4.3"
670+
optionalDependencies:
671+
rollup-plugin-visualizer "~1.1.1"
669672

670-
"@tensorflow/tfjs-data@1.0.4":
671-
version "1.0.4"
672-
resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-data/-/tfjs-data-1.0.4.tgz#1229a3217b529933a97b577004c95dce060a49e7"
673-
integrity sha512-Rf7M8ctirYfuKh+hu96pXi6g6yChEEELOvlCa1ta4JWqbQhpsu77DL7SrEtKeWzL+aD/GdCxmwuyofSnRSO7JA==
673+
"@tensorflow/tfjs-data@1.1.0":
674+
version "1.1.0"
675+
resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-data/-/tfjs-data-1.1.0.tgz#8d9a0175497930061532c8d43b419b25b26b9bbf"
676+
integrity sha512-0+PfAsaZs/pmaxiLunb4c1rPRdu47+CYe5kxpu2P8Xn3k+vhlBYMu+zsVgs5RrTRFLWVzVeH9muA1SJLkMGZPA==
674677
dependencies:
675678
"@types/node-fetch" "^2.1.2"
676679
node-fetch "~2.1.2"
677680

678-
"@tensorflow/tfjs-layers@1.0.4":
679-
version "1.0.4"
680-
resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-layers/-/tfjs-layers-1.0.4.tgz#4ebeb3183b3705f7dd00f1cd47a33ff3a0c46213"
681-
integrity sha512-X/iH0dh7AODZO1k44cx0YO19kHAItgUc9LEf5ZYH7bWaqrL0O+J4JYONqW5xzjLHunOFkKsvtNvZi9hIyemjGA==
681+
"@tensorflow/tfjs-layers@1.1.0":
682+
version "1.1.0"
683+
resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-layers/-/tfjs-layers-1.1.0.tgz#fd221c254d2fca13e93e83669bdde3140e7a0434"
684+
integrity sha512-a0gXjOWvGi9gc2q8/gK79zfD5WqEZnAhZfpm6b7AoKXjDUBq4GgdbbWCfv2nYBlmMoXgRSRSV44UmJVExep0uw==
682685

683686
"@tensorflow/tfjs-vis@^1.0.3":
684687
version "1.0.3"
@@ -692,15 +695,15 @@
692695
preact "^8.2.9"
693696
vega-embed "3.30.0"
694697

695-
"@tensorflow/tfjs@^1.0.4":
696-
version "1.0.4"
697-
resolved "https://registry.yarnpkg.com/@tensorflow/tfjs/-/tfjs-1.0.4.tgz#b571df49419e46753c6aa054d32c10787e83d928"
698-
integrity sha512-lTKYfk8VHG/acun7xRWz/A8FFsf2FK6aI0TvKrenmiaW4OZue8BYrFNbhLKdscSaEGvYZ7/RFWaOQkA2nrcI3Q==
698+
"@tensorflow/tfjs@^1.1.0":
699+
version "1.1.0"
700+
resolved "https://registry.yarnpkg.com/@tensorflow/tfjs/-/tfjs-1.1.0.tgz#77809dc336655a7ff0bbf76527cc3d7b9e68e330"
701+
integrity sha512-CxcFzl2KtknO3f12xuuv8kq8usMA7xGWpJajubIlYBp4KoBDhiDimP/DwBlTvFZq5RT5riHGtA4BWjMj6rnDcw==
699702
dependencies:
700-
"@tensorflow/tfjs-converter" "1.0.4"
701-
"@tensorflow/tfjs-core" "1.0.4"
702-
"@tensorflow/tfjs-data" "1.0.4"
703-
"@tensorflow/tfjs-layers" "1.0.4"
703+
"@tensorflow/tfjs-converter" "1.1.0"
704+
"@tensorflow/tfjs-core" "1.1.0"
705+
"@tensorflow/tfjs-data" "1.1.0"
706+
"@tensorflow/tfjs-layers" "1.1.0"
704707

705708
"@types/clone@^0.1.30":
706709
version "0.1.30"
@@ -4355,6 +4358,13 @@ opn@^5.1.0:
43554358
dependencies:
43564359
is-wsl "^1.1.0"
43574360

4361+
opn@^5.4.0:
4362+
version "5.5.0"
4363+
resolved "https://registry.yarnpkg.com/opn/-/opn-5.5.0.tgz#fc7164fab56d235904c51c3b27da6758ca3b9bfc"
4364+
integrity sha512-PqHpggC9bLV0VeWcdKhkpxY+3JTzetLSqTCWL/z/tFIbI6G8JCjondXklT1JinczLz2Xib62sSp0T/gKT4KksA==
4365+
dependencies:
4366+
is-wsl "^1.1.0"
4367+
43584368
43594369
version "0.6.1"
43604370
resolved "https://registry.yarnpkg.com/optimist/-/optimist-0.6.1.tgz#da3ea74686fa21a19a111c326e90eb15a0196686"
@@ -5613,6 +5623,16 @@ ripemd160@^2.0.0, ripemd160@^2.0.1:
56135623
hash-base "^3.0.0"
56145624
inherits "^2.0.1"
56155625

5626+
rollup-plugin-visualizer@~1.1.1:
5627+
version "1.1.1"
5628+
resolved "https://registry.yarnpkg.com/rollup-plugin-visualizer/-/rollup-plugin-visualizer-1.1.1.tgz#454ae0aed23845407ebfb81cc52114af308d6d90"
5629+
integrity sha512-7xkSKp+dyJmSC7jg2LXqViaHuOnF1VvIFCnsZEKjrgT5ZVyiLLSbeszxFcQSfNJILphqgAEmWAUz0Z4xYScrRw==
5630+
dependencies:
5631+
mkdirp "^0.5.1"
5632+
opn "^5.4.0"
5633+
source-map "^0.7.3"
5634+
typeface-oswald "0.0.54"
5635+
56165636
rw@1:
56175637
version "1.3.3"
56185638
resolved "https://registry.yarnpkg.com/rw/-/rw-1.3.3.tgz#3f862dfa91ab766b14885ef4d01124bfda074fb4"
@@ -5842,6 +5862,11 @@ source-map@^0.5.0, source-map@^0.5.3, source-map@^0.5.6:
58425862
resolved "https://registry.yarnpkg.com/source-map/-/source-map-0.5.7.tgz#8a039d2d1021d22d1ea14c80d8ea468ba2ef3fcc"
58435863
integrity sha1-igOdLRAh0i0eoUyA2OpGi6LvP8w=
58445864

5865+
source-map@^0.7.3:
5866+
version "0.7.3"
5867+
resolved "https://registry.yarnpkg.com/source-map/-/source-map-0.7.3.tgz#5302f8169031735226544092e64981f751750383"
5868+
integrity sha512-CkCj6giN3S+n9qrYiBTX5gystlENnRW5jZeNLHpe6aue+SrHcG5VYwujhW9s4dY31mEGsxBDrHR6oI69fTXsaQ==
5869+
58455870
spdx-correct@^3.0.0:
58465871
version "3.1.0"
58475872
resolved "https://registry.yarnpkg.com/spdx-correct/-/spdx-correct-3.1.0.tgz#fb83e504445268f154b074e218c87c003cd31df4"
@@ -6213,6 +6238,11 @@ typedarray@^0.0.6:
62136238
resolved "https://registry.yarnpkg.com/typedarray/-/typedarray-0.0.6.tgz#867ac74e3864187b1d3d47d996a78ec5c8830777"
62146239
integrity sha1-hnrHTjhkGHsdPUfZlqeOxciDB3c=
62156240

6241+
6242+
version "0.0.54"
6243+
resolved "https://registry.yarnpkg.com/typeface-oswald/-/typeface-oswald-0.0.54.tgz#1e253011622cdd50f580c04e7d625e7f449763d7"
6244+
integrity sha512-U1WMNp4qfy4/3khIfHMVAIKnNu941MXUfs3+H9R8PFgnoz42Hh9pboSFztWr86zut0eXC8byalmVhfkiKON/8Q==
6245+
62166246
typescript@^2.9.2:
62176247
version "2.9.2"
62186248
resolved "https://registry.yarnpkg.com/typescript/-/typescript-2.9.2.tgz#1cbf61d05d6b96269244eb6a3bce4bd914e0f00c"

0 commit comments

Comments
 (0)