6
6
import os
7
7
import pickle
8
8
from collections import ChainMap
9
- from typing import Any , List , Mapping , Optional , Tuple
9
+ from typing import Any , Mapping , Optional , Tuple
10
10
11
11
import keras
12
- import numpy as np
13
12
import tensorflow as tf
14
13
15
14
import tiledb
@@ -128,7 +127,7 @@ def load(
128
127
with tiledb .open (self .uri , ctx = self .ctx , timestamp = timestamp ) as model_array :
129
128
if self ._use_legacy_schema (model_array ):
130
129
return self .__load_legacy (
131
- model_array , compile_model , callback , custom_objects , input_shape
130
+ model_array , compile_model , callback , custom_objects
132
131
)
133
132
else :
134
133
return self .__load (model_array , compile_model , callback )
@@ -139,30 +138,15 @@ def __load_legacy(
139
138
compile_model : bool ,
140
139
callback : bool ,
141
140
custom_objects : Optional [Mapping [str , Any ]],
142
- input_shape : Optional [Tuple [int , ...]],
143
141
) -> tf .keras .Model :
144
142
model_array_results = model_array [:]
145
143
model_config = json .loads (model_array .meta ["model_config" ])
146
144
model_class = model_config ["class_name" ]
147
145
148
- if model_class not in ("Functional" , "Sequential" ):
149
- with SharedObjectLoadingScope ():
150
- with tf .keras .utils .CustomObjectScope (custom_objects or {}):
151
- if hasattr (model_config , "decode" ):
152
- model_config = model_config .decode ("utf-8" )
153
- model = tf .keras .models .model_from_config (
154
- model_config , custom_objects = custom_objects
155
- )
156
- if not model .built :
157
- model .build (input_shape )
158
-
159
- # Load weights for layers
160
- self ._load_custom_subclassed_model (model , model_array )
161
- else :
162
- cls = tf .keras .Sequential if model_class == "Sequential" else tf .keras .Model
163
- model = cls .from_config (model_config ["config" ])
164
- model_weights = pickle .loads (model_array_results ["model_weights" ].item (0 ))
165
- model .set_weights (model_weights )
146
+ cls = tf .keras .Sequential if model_class == "Sequential" else tf .keras .Model
147
+ model = cls .from_config (model_config ["config" ])
148
+ model_weights = pickle .loads (model_array_results ["model_weights" ].item (0 ))
149
+ model .set_weights (model_weights )
166
150
167
151
if compile_model :
168
152
optimizer_weights = pickle .loads (
@@ -198,6 +182,7 @@ def __load_legacy(
198
182
"starting with a freshly initialized "
199
183
"optimizer."
200
184
)
185
+
201
186
if callback :
202
187
try :
203
188
with tiledb .open (f"{ self .uri } -tensorboard" ) as tb_array :
@@ -239,6 +224,7 @@ def __load(
239
224
saving_utils .try_build_compiled_arguments (model )
240
225
241
226
optimizer_weights = self ._get_model_param (model_array , "optimizer" )
227
+
242
228
# Set optimizer weights.
243
229
if optimizer_weights :
244
230
try :
@@ -284,70 +270,3 @@ def _serialize_optimizer_weights(
284
270
optimizer_weights = tf .keras .backend .batch_get_value (optimizer .weights )
285
271
return pickle .dumps (optimizer_weights , protocol = 4 )
286
272
return b""
287
-
288
- def _load_custom_subclassed_model (
289
- self , model : tf .keras .Model , model_array : tiledb .Array
290
- ) -> None :
291
- if "keras_version" in model_array .meta :
292
- original_keras_version = model_array .meta ["keras_version" ]
293
- if hasattr (original_keras_version , "decode" ):
294
- original_keras_version = original_keras_version .decode ("utf8" )
295
- else :
296
- original_keras_version = "1"
297
- if "backend" in model_array .meta :
298
- original_backend = model_array .meta ["backend" ]
299
- if hasattr (original_backend , "decode" ):
300
- original_backend = original_backend .decode ("utf8" )
301
- else :
302
- original_backend = None
303
-
304
- # Load weights for layers
305
- self ._load_weights_from_tiledb (
306
- model_array [:], model , original_keras_version , original_backend
307
- )
308
-
309
- @staticmethod
310
- def _load_weights_from_tiledb (
311
- model_array_results : Mapping [str , Any ],
312
- model : tf .keras .Model ,
313
- original_keras_version : Optional [str ],
314
- original_backend : Optional [str ],
315
- ) -> None :
316
- num_layers = 0
317
- for layer in model .layers :
318
- weights = layer .trainable_weights + layer .non_trainable_weights
319
- if weights :
320
- num_layers += 1
321
-
322
- read_layer_names = []
323
- for k , name in enumerate (model_array_results ["layer_name" ]):
324
- layer_weight_names = pickle .loads (
325
- model_array_results ["weight_names" ].item (k )
326
- )
327
- if layer_weight_names :
328
- read_layer_names .append (name )
329
-
330
- if len (read_layer_names ) != num_layers :
331
- raise ValueError (
332
- f"You are trying to load a weight file with { len (read_layer_names )} "
333
- f"layers into a model with { num_layers } layers"
334
- )
335
-
336
- var_value_tuples : List [Tuple [tf .Variable , np .ndarray ]] = []
337
- for k , layer in enumerate (model .layers ):
338
- weight_vars = layer .trainable_weights + layer .non_trainable_weights
339
- read_weight_values = pickle .loads (
340
- model_array_results ["weight_values" ].item (k )
341
- )
342
- read_weight_values = preprocess_weights_for_loading (
343
- layer , read_weight_values , original_keras_version , original_backend
344
- )
345
- if len (read_weight_values ) != len (weight_vars ):
346
- raise ValueError (
347
- f'Layer #{ k } (named "{ layer .name } " in the current model) was found '
348
- f"to correspond to layer { layer } in the save file. However the new "
349
- f"layer { layer .name } expects { len (weight_vars )} weights, "
350
- f"but the saved weights have { len (read_weight_values )} elements"
351
- )
352
- var_value_tuples .extend (zip (weight_vars , read_weight_values ))
353
- tf .keras .backend .batch_set_value (var_value_tuples )
0 commit comments