Skip to content

Commit 21d2eb8

Browse files
authored
[mnist-transfer-cnn] Use modelFromJSON; update dependencies (#333)
- Replace the ad-hoc code that duplicates the model topology (while re-initializing weights) with modelFromJSON. This resolves a TODO item. - Update tfjs and tfjs-vis versions to the latest.
1 parent 83cc0f1 commit 21d2eb8

File tree

3 files changed

+531
-466
lines changed

3 files changed

+531
-466
lines changed

mnist-transfer-cnn/index.js

+6-19
Original file line numberDiff line numberDiff line change
@@ -121,25 +121,12 @@ class MnistTransferCNNPredictor {
121121
this.model.layers[i].trainable = false;
122122
}
123123
} else if (trainingMode === 'reinitialize-weights') {
124-
// TODO(cais): Use tf.models.modelFromJSON() once it's available in the
125-
// public API.
126-
const oldLayers = this.model.layers;
127-
this.model = tf.sequential();
128-
for (const layer of oldLayers) {
129-
const layerType = layer.getClassName();
130-
const layerTypeMap = {
131-
'Activation': 'activation',
132-
'Conv2D': 'conv2d',
133-
'Dense': 'dense',
134-
'Dropout': 'dropout',
135-
'Flatten': 'flatten',
136-
'MaxPooling2D': 'maxPooling2d'
137-
};
138-
const jsLayerType = layerTypeMap[layerType];
139-
this.model.add(tf.layers[jsLayerType](layer.getConfig()));
140-
}
141-
// TODO(cais): Use tfVis.show.modelSummary().
142-
this.model.summary();
124+
// Make a model with the same topology as before, but with re-initialized
125+
// weight values.
126+
const returnString = false;
127+
this.model = await tf.models.modelFromJSON({
128+
modelTopology: this.model.toJSON(null, returnString)
129+
});
143130
}
144131
this.model.compile({
145132
loss: 'categoricalCrossentropy',

mnist-transfer-cnn/package.json

+2-2
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
"node": ">=8.9.0"
1010
},
1111
"dependencies": {
12-
"@tensorflow/tfjs": "^1.2.8",
13-
"@tensorflow/tfjs-vis": "^1.1.0"
12+
"@tensorflow/tfjs": "^1.2.9",
13+
"@tensorflow/tfjs-vis": "^1.2.0"
1414
},
1515
"scripts": {
1616
"watch": "./serve.sh",

0 commit comments

Comments
 (0)