@@ -36,6 +36,17 @@ export interface ContainerArgs {
36
36
name ?: string ;
37
37
}
38
38
39
+ // get weights key from tensor map in order to check if it is from keras v3.
40
+ // e.g. dense/0
41
+ const isKerasSavedModelFormat = ( weights : NamedTensorMap ) : boolean => {
42
+ const keys = Object . keys ( weights ) ;
43
+ if ( keys . length === 0 ) {
44
+ return false ;
45
+ }
46
+ const key = keys [ 0 ] . split ( '/' ) ;
47
+ return ! isNaN ( parseInt ( key [ key . length - 1 ] , 10 ) ) ;
48
+ } ;
49
+
39
50
/**
40
51
* A Container is a directed acyclic graph of layers.
41
52
*
@@ -594,19 +605,16 @@ export abstract class Container extends Layer {
594
605
loadWeights ( weights : NamedTensorMap , strict = true ) {
595
606
const nameToWeight : { [ name : string ] : LayerVariable } = { } ;
596
607
let totalWeightsCount = 0 ;
597
- // get weights key from tensor map in order to check if it is from keras v3.
598
- // e.g. dense/0
599
- const key = Object . keys ( weights ) [ 0 ] . split ( '/' ) ;
600
- const isKerasSavedModelFormat = ! isNaN ( parseInt ( key [ key . length - 1 ] , 10 ) ) ;
601
- if ( isKerasSavedModelFormat ) {
608
+ const modelIsKerasSavedModelFormat = isKerasSavedModelFormat ( weights ) ;
609
+ if ( modelIsKerasSavedModelFormat ) {
602
610
this . parseWeights ( weights ) ;
603
611
}
604
612
// Check if weights from keras v3.
605
613
for ( const layer of this . layers ) {
606
614
for ( const [ index , weight ] of layer . weights . entries ( ) ) {
607
615
// Parse the name to layerName/index.
608
616
// e.g. dense/0, dense/1, dense_1/0, dense_1/1
609
- const parsedName = isKerasSavedModelFormat ?
617
+ const parsedName = modelIsKerasSavedModelFormat ?
610
618
`${ weight . name . split ( '/' ) . slice ( 0 , - 1 ) . join ( '/' ) + '/' } ${ index } ` :
611
619
weight . originalName ;
612
620
if ( nameToWeight [ parsedName ] != null ) {
0 commit comments