Skip to content

Commit a6d5408

Browse files
authored
[snake-dqn] Fix a bug in copyWeights(); Upgrade to tfjs 1.2.7 (#323)
Workaround for tensorflow/tfjs#1807
1 parent f2fc4a6 commit a6d5408

File tree

5 files changed

+93
-67
lines changed

5 files changed

+93
-67
lines changed

Diff for: snake-dqn/agent.js

+3
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,10 @@ export class SnakeGameAgent {
150150
return tf.losses.meanSquaredError(targetQs, qs);
151151
});
152152

153+
// Calculate the gradients of the loss function with repsect to the weights
154+
// of the online DQN.
153155
const grads = tf.variableGrads(lossFunction);
156+
// Use the gradients to update the online DQN's weights.
154157
optimizer.applyGradients(grads.grads);
155158
tf.dispose(grads);
156159
// TODO(cais): Return the loss value here?

Diff for: snake-dqn/dqn.js

+20
Original file line numberDiff line numberDiff line change
@@ -68,5 +68,25 @@ export function createDeepQNetwork(h, w, numActions) {
6868
* @param {tf.LayersModel} srcNetwork The source network for weight copying.
6969
*/
7070
export function copyWeights(destNetwork, srcNetwork) {
71+
// https://github.com/tensorflow/tfjs/issues/1807:
72+
// Weight orders are inconsistent when the trainable attribute doesn't
73+
// match between two `LayersModel`s. The following is a workaround.
74+
// TODO(cais): Remove the workaround once the underlying issue is fixed.
75+
let originalDestNetworkTrainable;
76+
if (destNetwork.trainable !== srcNetwork.trainable) {
77+
originalDestNetworkTrainable = destNetwork.trainable;
78+
destNetwork.trainable = srcNetwork.trainable;
79+
}
80+
7181
destNetwork.setWeights(srcNetwork.getWeights());
82+
83+
// Weight orders are inconsistent when the trainable attribute doesn't
84+
// match between two `LayersModel`s. The following is a workaround.
85+
// TODO(cais): Remove the workaround once the underlying issue is fixed.
86+
// `originalDestNetworkTrainable` is null if and only if the `trainable`
87+
// properties of the two LayersModel instances are the same to begin
88+
// with, in which case nothing needs to be done below.
89+
if (originalDestNetworkTrainable != null) {
90+
destNetwork.trainable = originalDestNetworkTrainable;
91+
}
7292
}

Diff for: snake-dqn/dqn_test.js

+14
Original file line numberDiff line numberDiff line change
@@ -113,4 +113,18 @@ describe('copyWeights', () => {
113113
.toEqual(0);
114114
}
115115
});
116+
117+
it('Copy from trainble source to untrainble dest works', () => {
118+
// Covers https://github.com/tensorflow/tfjs/issues/1807.
119+
const h = 9;
120+
const w = 9;
121+
const numActions = 4;
122+
const srcNetwork = createDeepQNetwork(h, w, numActions);
123+
const destNetwork = createDeepQNetwork(h, w, numActions);
124+
125+
destNetwork.trainable = false;
126+
copyWeights(destNetwork, srcNetwork);
127+
expect(destNetwork.trainable).toEqual(false);
128+
expect(srcNetwork.trainable).toEqual(true);
129+
});
116130
});

Diff for: snake-dqn/package.json

+3-3
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
"node": ">=8.9.0"
1010
},
1111
"dependencies": {
12-
"@tensorflow/tfjs": "^1.2.2"
12+
"@tensorflow/tfjs": "^1.2.7"
1313
},
1414
"scripts": {
1515
"build": "cross-env NODE_ENV=production parcel build index.html --no-minify --public-url ./",
@@ -19,8 +19,8 @@
1919
"watch": "mkdir -p dist && cp -r models/dqn dist/dqn && cross-env NODE_ENV=development parcel index.html --no-hmr --open"
2020
},
2121
"devDependencies": {
22-
"@tensorflow/tfjs-node": "^1.2.3",
23-
"@tensorflow/tfjs-node-gpu": "^1.2.3",
22+
"@tensorflow/tfjs-node": "^1.2.7",
23+
"@tensorflow/tfjs-node-gpu": "^1.2.7",
2424
"argparse": "^1.0.10",
2525
"babel-cli": "^6.26.0",
2626
"babel-core": "^6.26.3",

Diff for: snake-dqn/yarn.lock

+53-64
Original file line numberDiff line numberDiff line change
@@ -677,73 +677,71 @@
677677
resolved "https://registry.yarnpkg.com/@nodelib/fs.stat/-/fs.stat-1.1.3.tgz#2b5a3ab3f918cca48a8c754c08168e3f03eba61b"
678678
integrity sha512-shAmDyaQC4H92APFoIaVDHCx5bStIocgvbwQyxPRrbUY20V1EYTbSDchWbuwlMG3V17cprZhA6+78JfB+3DTPw==
679679

680-
"@tensorflow/[email protected].2":
681-
version "1.2.2"
682-
resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-converter/-/tfjs-converter-1.2.2.tgz#c95e2f79b1de830b8079c7704dc8463ced2d2b79"
683-
integrity sha512-NM2NcPRHpCNeJdBxHcYpmW9ZHTQ2lJFJgmgGpQ8CxSC9CtQB05bFONs3SKcwMNDE/69QBRVom5DYqLCVUg+A+g==
680+
"@tensorflow/[email protected].7":
681+
version "1.2.7"
682+
resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-converter/-/tfjs-converter-1.2.7.tgz#3a1cf6d636667586010c4df79e861293ce10c4c8"
683+
integrity sha512-7MRH21zNgOOYGuztcrZwx6eXhg8gn/QvIZWljsmTJ33fnIyWLSAIHXsp5WgWalfAsj40yU92TnxaIyRuPeD/mw==
684684

685-
"@tensorflow/[email protected].2":
686-
version "1.2.2"
687-
resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-core/-/tfjs-core-1.2.2.tgz#2efa89e323612a26aeccee9b3ae9f5ac5a635bbe"
688-
integrity sha512-2hCHMKjh3UNpLEjbAEaurrTGJyj/KpLtMSAraWgHA1vGY0kmk50BBSbgCDmXWUVm7lyh/SkCq4/GrGDZktEs3g==
685+
"@tensorflow/[email protected].7":
686+
version "1.2.7"
687+
resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-core/-/tfjs-core-1.2.7.tgz#522328de16470aa9f7c15b91e4b68616f425002a"
688+
integrity sha512-RsXavYKMc0MOcCmOyD7HE8am1tWlDGXl0nJbsdib7ubmvMuH6KnrZ302eTYV7k1RMq+/ukkioJmCcw13hopuHQ==
689689
dependencies:
690690
"@types/offscreencanvas" "~2019.3.0"
691691
"@types/seedrandom" "2.4.27"
692692
"@types/webgl-ext" "0.0.30"
693693
"@types/webgl2" "0.0.4"
694694
node-fetch "~2.1.2"
695695
seedrandom "2.4.3"
696-
optionalDependencies:
697-
rollup-plugin-visualizer "~1.1.1"
698696

699-
"@tensorflow/[email protected].2":
700-
version "1.2.2"
701-
resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-data/-/tfjs-data-1.2.2.tgz#bd802b4096df04277d302d66598aef47fbffef85"
702-
integrity sha512-oHGBoGdnCl2RyouLKplQqo+iil0iJgPbi/aoHizhpO77UBuJXlKMblH8w5GbxVAw3hKxWlqzYpxPo6rVRgehNA==
697+
"@tensorflow/[email protected].7":
698+
version "1.2.7"
699+
resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-data/-/tfjs-data-1.2.7.tgz#06f5315a045ca680e721688e64976a418f43f979"
700+
integrity sha512-Q274DhAHkN6wJNEeNkKj7p++lueVd0QZXWTIVoyuzN7o25dJooStjomYjZimWIi1f81/3m2AoMtfA7y0uPGDiA==
703701
dependencies:
704702
"@types/node-fetch" "^2.1.2"
705703
node-fetch "~2.1.2"
706704

707-
"@tensorflow/[email protected].2":
708-
version "1.2.2"
709-
resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-layers/-/tfjs-layers-1.2.2.tgz#3365dbbca7cfa4fcc6cacc9fffc90d664606bd4e"
710-
integrity sha512-yzWZaZrCVpEyTkSrzMe4OOP4aGUfaaROE/zR9fPsPGGF8wLlbLNZUJjeYUmjy3G3pXGaM0mQUbLR5Vd707CVtQ==
705+
"@tensorflow/[email protected].7":
706+
version "1.2.7"
707+
resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-layers/-/tfjs-layers-1.2.7.tgz#4d9ad8966b5f1833ddabd2581c282cd288a69968"
708+
integrity sha512-qc9qRmJsizRon2HyIKKtWnZOQFMKdaWDeLlHyZUkrOn7YxqjFvxLs9BXsjKR1IuBliOljEe+ZEePDpasUtYLng==
711709

712-
"@tensorflow/tfjs-node-gpu@^1.2.3":
713-
version "1.2.3"
714-
resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-node-gpu/-/tfjs-node-gpu-1.2.3.tgz#3786d814bc5ca4c10e88a4a490feea65a39bd8cf"
715-
integrity sha512-y8A1dF4WZZ+IvCCv/hrEUVV9O1ua0f5rZVzaMnJx+xv8o51DwTGk7h6tsnE/F2N6pf9mKLsY8roUBviIasVEmQ==
710+
"@tensorflow/tfjs-node-gpu@^1.2.7":
711+
version "1.2.7"
712+
resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-node-gpu/-/tfjs-node-gpu-1.2.7.tgz#18d5b2f173b1421344555d5d1c8f18e6111b84b0"
713+
integrity sha512-eN4lpkhrrkXSDeFCiRCzyxlA+YfIdS6jvarKsqUHFzMOjcIveN1eaBEV0FVl4bXRMfTjl8c537Y/pGiTZ6z/vA==
716714
dependencies:
717-
"@tensorflow/tfjs" "~1.2.2"
715+
"@tensorflow/tfjs" "~1.2.7"
718716
adm-zip "^0.4.11"
719-
bindings "~1.3.0"
720717
https-proxy-agent "^2.2.1"
718+
node-pre-gyp "0.13.0"
721719
progress "^2.0.0"
722720
rimraf "^2.6.2"
723721
tar "^4.4.6"
724722

725-
"@tensorflow/tfjs-node@^1.2.3":
726-
version "1.2.3"
727-
resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-node/-/tfjs-node-1.2.3.tgz#b2a6c3051da080a853be34b4bdc6649479139852"
728-
integrity sha512-6/V3JfoxnvUJhZle8+7V0ln7KjUIJOlDCk43EBQg+XoGudvp3L1x0RXcfCQ1nXFIlZVYixNJYd3XTIOHZBECSA==
723+
"@tensorflow/tfjs-node@^1.2.7":
724+
version "1.2.7"
725+
resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-node/-/tfjs-node-1.2.7.tgz#da31ca8d4cbeac083511c0f4f5caa88bd0b28ee8"
726+
integrity sha512-4JhhGcEEjiVzMPmjg7I/zMFYYWAF4oHFOkdN2WLPeb7miaZSr0+xpJwgXbT8fkm3t93YCW1vVJdeX16KZ6urLA==
729727
dependencies:
730-
"@tensorflow/tfjs" "~1.2.2"
728+
"@tensorflow/tfjs" "~1.2.7"
731729
adm-zip "^0.4.11"
732-
bindings "~1.3.0"
733730
https-proxy-agent "^2.2.1"
731+
node-pre-gyp "0.13.0"
734732
progress "^2.0.0"
735733
rimraf "^2.6.2"
736734
tar "^4.4.6"
737735

738-
"@tensorflow/tfjs@^1.2.2", "@tensorflow/tfjs@~1.2.2":
739-
version "1.2.2"
740-
resolved "https://registry.yarnpkg.com/@tensorflow/tfjs/-/tfjs-1.2.2.tgz#e0cc7f1c4139e7c38f3ea478999f0972d354c948"
741-
integrity sha512-HfhSzL2eTWhlT0r/A5wmo+u3bHe+an16p5wsnFH3ujn21fQ8QtGpSfDHQZjWx1kVFaQnV6KBG+17MOrRHoHlLA==
736+
"@tensorflow/tfjs@^1.2.7", "@tensorflow/tfjs@~1.2.7":
737+
version "1.2.7"
738+
resolved "https://registry.yarnpkg.com/@tensorflow/tfjs/-/tfjs-1.2.7.tgz#5da167a56cda8d0fb6aa2f7cf1b1d6b36a99a7f7"
739+
integrity sha512-C5wCPic5kWDHLEoDCJsuypfge/vw9CVYpHAOzM57PXhVdYBNQRFNi2mA4KhB6jg8gs/e9u5Ei4lA1I0rhPorAQ==
742740
dependencies:
743-
"@tensorflow/tfjs-converter" "1.2.2"
744-
"@tensorflow/tfjs-core" "1.2.2"
745-
"@tensorflow/tfjs-data" "1.2.2"
746-
"@tensorflow/tfjs-layers" "1.2.2"
741+
"@tensorflow/tfjs-converter" "1.2.7"
742+
"@tensorflow/tfjs-core" "1.2.7"
743+
"@tensorflow/tfjs-data" "1.2.7"
744+
"@tensorflow/tfjs-layers" "1.2.7"
747745

748746
"@types/node-fetch@^2.1.2":
749747
version "2.3.0"
@@ -1572,11 +1570,6 @@ bindings@~1.2.1:
15721570
resolved "https://registry.yarnpkg.com/bindings/-/bindings-1.2.1.tgz#14ad6113812d2d37d72e67b4cacb4bb726505f11"
15731571
integrity sha1-FK1hE4EtLTfXLme0ystLtyZQXxE=
15741572

1575-
bindings@~1.3.0:
1576-
version "1.3.1"
1577-
resolved "https://registry.yarnpkg.com/bindings/-/bindings-1.3.1.tgz#21fc7c6d67c18516ec5aaa2815b145ff77b26ea5"
1578-
integrity sha512-i47mqjF9UbjxJhxGf+pZ6kSxrnI3wBLlnGI2ArWJ4r0VrvDS7ZYXkprq/pLaBWYq4GM0r4zdHY+NNRqEMU7uew==
1579-
15801573
bn.js@^4.0.0, bn.js@^4.1.0, bn.js@^4.1.1, bn.js@^4.4.0:
15811574
version "4.11.8"
15821575
resolved "https://registry.yarnpkg.com/bn.js/-/bn.js-4.11.8.tgz#2cde09eb5ee341f484746bb0309b3253b1b1442f"
@@ -4117,6 +4110,22 @@ node-libs-browser@^2.0.0:
41174110
util "^0.11.0"
41184111
vm-browserify "0.0.4"
41194112

4113+
4114+
version "0.13.0"
4115+
resolved "https://registry.yarnpkg.com/node-pre-gyp/-/node-pre-gyp-0.13.0.tgz#df9ab7b68dd6498137717838e4f92a33fc9daa42"
4116+
integrity sha512-Md1D3xnEne8b/HGVQkZZwV27WUi1ZRuZBij24TNaZwUPU3ZAFtvT6xxJGaUVillfmMKnn5oD1HoGsp2Ftik7SQ==
4117+
dependencies:
4118+
detect-libc "^1.0.2"
4119+
mkdirp "^0.5.1"
4120+
needle "^2.2.1"
4121+
nopt "^4.0.1"
4122+
npm-packlist "^1.1.6"
4123+
npmlog "^4.0.2"
4124+
rc "^1.2.7"
4125+
rimraf "^2.6.1"
4126+
semver "^5.3.0"
4127+
tar "^4"
4128+
41204129
node-pre-gyp@^0.10.0:
41214130
version "0.10.3"
41224131
resolved "https://registry.yarnpkg.com/node-pre-gyp/-/node-pre-gyp-0.10.3.tgz#3070040716afdc778747b61b6887bf78880b80fc"
@@ -4323,7 +4332,7 @@ onetime@^2.0.0:
43234332
dependencies:
43244333
mimic-fn "^1.0.0"
43254334

4326-
opn@^5.1.0, opn@^5.4.0:
4335+
opn@^5.1.0:
43274336
version "5.5.0"
43284337
resolved "https://registry.yarnpkg.com/opn/-/opn-5.5.0.tgz#fc7164fab56d235904c51c3b27da6758ca3b9bfc"
43294338
integrity sha512-PqHpggC9bLV0VeWcdKhkpxY+3JTzetLSqTCWL/z/tFIbI6G8JCjondXklT1JinczLz2Xib62sSp0T/gKT4KksA==
@@ -5555,16 +5564,6 @@ ripemd160@^2.0.0, ripemd160@^2.0.1:
55555564
hash-base "^3.0.0"
55565565
inherits "^2.0.1"
55575566

5558-
rollup-plugin-visualizer@~1.1.1:
5559-
version "1.1.1"
5560-
resolved "https://registry.yarnpkg.com/rollup-plugin-visualizer/-/rollup-plugin-visualizer-1.1.1.tgz#454ae0aed23845407ebfb81cc52114af308d6d90"
5561-
integrity sha512-7xkSKp+dyJmSC7jg2LXqViaHuOnF1VvIFCnsZEKjrgT5ZVyiLLSbeszxFcQSfNJILphqgAEmWAUz0Z4xYScrRw==
5562-
dependencies:
5563-
mkdirp "^0.5.1"
5564-
opn "^5.4.0"
5565-
source-map "^0.7.3"
5566-
typeface-oswald "0.0.54"
5567-
55685567
safe-buffer@^5.0.1, safe-buffer@^5.1.0, safe-buffer@^5.1.1, safe-buffer@^5.1.2, safe-buffer@~5.1.0, safe-buffer@~5.1.1:
55695568
version "5.1.2"
55705569
resolved "https://registry.yarnpkg.com/safe-buffer/-/safe-buffer-5.1.2.tgz#991ec69d296e0313747d59bdfd2b745c35f8828d"
@@ -5815,11 +5814,6 @@ source-map@^0.5.0, source-map@^0.5.3, source-map@^0.5.6, source-map@^0.5.7:
58155814
resolved "https://registry.yarnpkg.com/source-map/-/source-map-0.5.7.tgz#8a039d2d1021d22d1ea14c80d8ea468ba2ef3fcc"
58165815
integrity sha1-igOdLRAh0i0eoUyA2OpGi6LvP8w=
58175816

5818-
source-map@^0.7.3:
5819-
version "0.7.3"
5820-
resolved "https://registry.yarnpkg.com/source-map/-/source-map-0.7.3.tgz#5302f8169031735226544092e64981f751750383"
5821-
integrity sha512-CkCj6giN3S+n9qrYiBTX5gystlENnRW5jZeNLHpe6aue+SrHcG5VYwujhW9s4dY31mEGsxBDrHR6oI69fTXsaQ==
5822-
58235817
spdx-correct@^3.0.0:
58245818
version "3.1.0"
58255819
resolved "https://registry.yarnpkg.com/spdx-correct/-/spdx-correct-3.1.0.tgz#fb83e504445268f154b074e218c87c003cd31df4"
@@ -6183,11 +6177,6 @@ typedarray@^0.0.6:
61836177
resolved "https://registry.yarnpkg.com/typedarray/-/typedarray-0.0.6.tgz#867ac74e3864187b1d3d47d996a78ec5c8830777"
61846178
integrity sha1-hnrHTjhkGHsdPUfZlqeOxciDB3c=
61856179

6186-
6187-
version "0.0.54"
6188-
resolved "https://registry.yarnpkg.com/typeface-oswald/-/typeface-oswald-0.0.54.tgz#1e253011622cdd50f580c04e7d625e7f449763d7"
6189-
integrity sha512-U1WMNp4qfy4/3khIfHMVAIKnNu941MXUfs3+H9R8PFgnoz42Hh9pboSFztWr86zut0eXC8byalmVhfkiKON/8Q==
6190-
61916180
unicode-canonical-property-names-ecmascript@^1.0.4:
61926181
version "1.0.4"
61936182
resolved "https://registry.yarnpkg.com/unicode-canonical-property-names-ecmascript/-/unicode-canonical-property-names-ecmascript-1.0.4.tgz#2619800c4c825800efdd8343af7dd9933cbe2818"

0 commit comments

Comments
 (0)