Skip to content

Commit 20ceb6f

Browse files
thekevinscottfengwuyaopyu10055
authored
Fix bug where a model with empty weights fails to load (#7868)
* Fix bug where a model with empty weights fails to load * Address linting complaints --------- Co-authored-by: fengwuyao <[email protected]> Co-authored-by: Ping Yu <[email protected]>
1 parent 0cd53ba commit 20ceb6f

File tree

2 files changed

+46
-6
lines changed

2 files changed

+46
-6
lines changed

tfjs-layers/src/engine/container.ts

+14-6
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,17 @@ export interface ContainerArgs {
3636
name?: string;
3737
}
3838

39+
// get weights key from tensor map in order to check if it is from keras v3.
40+
// e.g. dense/0
41+
const isKerasSavedModelFormat = (weights: NamedTensorMap): boolean => {
42+
const keys = Object.keys(weights);
43+
if (keys.length === 0) {
44+
return false;
45+
}
46+
const key = keys[0].split('/');
47+
return !isNaN(parseInt(key[key.length - 1], 10));
48+
};
49+
3950
/**
4051
* A Container is a directed acyclic graph of layers.
4152
*
@@ -594,19 +605,16 @@ export abstract class Container extends Layer {
594605
loadWeights(weights: NamedTensorMap, strict = true) {
595606
const nameToWeight: {[name: string]: LayerVariable} = {};
596607
let totalWeightsCount = 0;
597-
// get weights key from tensor map in order to check if it is from keras v3.
598-
// e.g. dense/0
599-
const key = Object.keys(weights)[0].split('/');
600-
const isKerasSavedModelFormat = !isNaN(parseInt(key[key.length - 1], 10));
601-
if (isKerasSavedModelFormat) {
608+
const modelIsKerasSavedModelFormat = isKerasSavedModelFormat(weights);
609+
if (modelIsKerasSavedModelFormat) {
602610
this.parseWeights(weights);
603611
}
604612
// Check if weights from keras v3.
605613
for (const layer of this.layers) {
606614
for (const [index, weight] of layer.weights.entries()) {
607615
// Parse the name to layerName/index.
608616
// e.g. dense/0, dense/1, dense_1/0, dense_1/1
609-
const parsedName = isKerasSavedModelFormat ?
617+
const parsedName = modelIsKerasSavedModelFormat ?
610618
`${weight.name.split('/').slice(0, -1).join('/') + '/'}${index}` :
611619
weight.originalName;
612620
if (nameToWeight[parsedName] != null) {

tfjs-layers/src/model_save_test.ts

+32
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,38 @@ describeMathCPUAndWebGL2('Save-load round trips', () => {
140140
}
141141
});
142142

143+
it('loadLayersModel: save and load a model with empty weights', async () => {
144+
// https://github.com/tensorflow/tfjs/issues/7865
145+
// Models without weights should still be valid models
146+
const model1 = tfl.sequential();
147+
model1.add(
148+
tfl.layers.upSampling2d({
149+
size: [2, 2],
150+
dataFormat: 'channelsLast',
151+
inputShape: [null, null, 3],
152+
})
153+
);
154+
155+
// Use a randomly generated model path to prevent collision.
156+
const path = `testModel${new Date().getTime()}_${Math.random()}`;
157+
158+
// First save the model to local storage.
159+
const modelURL = `localstorage://${path}`;
160+
await model1.save(modelURL);
161+
// Once the saving succeeds, load the model back.
162+
const model2 = await tfl.loadLayersModel(modelURL);
163+
// Verify that the topology of the model is correct.
164+
expect(model2.toJSON(null, false)).toEqual(model1.toJSON(null, false));
165+
166+
// Check the equality of the two models' weights.
167+
const weights1 = model1.getWeights();
168+
const weights2 = model2.getWeights();
169+
expect(weights2.length).toEqual(weights1.length);
170+
for (let i = 0; i < weights1.length; ++i) {
171+
expectTensorsClose(weights1[i], weights2[i]);
172+
}
173+
});
174+
143175
it('Functional model, IndexedDB', async () => {
144176
const input = tfl.input({shape: [2, 2]});
145177
const layer1 = tfl.layers.flatten().apply(input);

0 commit comments

Comments
 (0)