Skip to content

Commit a40a555

Browse files
authored
Merge pull request #210 from TileDB-Inc/sethshelnutt/sc-31839/tensorflowkerastiledbmodel-doesn-t-load
Update to support tensorflow >= 2.11
2 parents 10c8509 + aff195f commit a40a555

File tree

3 files changed

+116
-20
lines changed

3 files changed

+116
-20
lines changed

.github/workflows/ci.yml

+8-1
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,17 @@ jobs:
99
strategy:
1010
fail-fast: false
1111
matrix:
12+
python-verison: ["3.7"]
1213
ml-deps:
1314
- "torch==1.11.0+cpu torchvision==0.12.0+cpu torchdata==0.3.0 tensorflow-cpu==2.8.1"
1415
- "torch==1.12.1+cpu torchvision==0.13.1+cpu torchdata==0.4.1 tensorflow-cpu==2.9.1"
1516
- "torch==1.13.0+cpu torchvision==0.14.0+cpu torchdata==0.5.0 tensorflow-cpu==2.10.0"
17+
- "torch==1.13.0+cpu torchvision==0.14.0+cpu torchdata==0.5.0 tensorflow-cpu==2.11.0"
18+
include:
19+
- ml-deps: "torch==1.13.0+cpu torchvision==0.14.0+cpu torchdata==0.5.0 tensorflow-cpu==2.12.0"
20+
python-version: "3.9"
21+
- ml-deps: "torch==1.13.0+cpu torchvision==0.14.0+cpu torchdata==0.5.0 tensorflow-cpu==2.13.0"
22+
python-version: "3.9"
1623

1724
env:
1825
run_coverage: ${{ github.ref == 'refs/heads/master' }}
@@ -23,7 +30,7 @@ jobs:
2330
- name: Set up Python ${{ matrix.python-version }}
2431
uses: actions/setup-python@v4
2532
with:
26-
python-version: "3.7"
33+
python-version: ${{ matrix.python-version }}
2734

2835
- name: Cache dependencies
2936
uses: actions/cache@v3

tests/models/test_tensorflow_keras_models.py

+54-9
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,16 @@
2121
get_small_sequential_mlp,
2222
)
2323
except ImportError:
24-
from keras.testing_utils import get_small_functional_mlp, get_small_sequential_mlp
24+
try:
25+
from keras.testing_utils import (
26+
get_small_functional_mlp,
27+
get_small_sequential_mlp,
28+
)
29+
except ImportError:
30+
from keras.src.testing_infra.test_utils import (
31+
get_small_functional_mlp,
32+
get_small_sequential_mlp,
33+
)
2534

2635
# Suppress all Tensorflow messages
2736
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
@@ -166,8 +175,20 @@ def test_save_model_to_tiledb_array_weights(
166175
data = np.random.rand(100, 3)
167176

168177
if optimizer:
169-
model_opt_weights = batch_get_value(model.optimizer.weights)
170-
loaded_opt_weights = batch_get_value(loaded_model.optimizer.weights)
178+
if hasattr(model.optimizer, "weights"):
179+
model_opt_weights = tf.keras.backend.batch_get_value(
180+
model.optimizer.weights
181+
)
182+
else:
183+
model_opt_weights = [var.numpy() for var in model.optimizer.variables()]
184+
if hasattr(loaded_model.optimizer, "weights"):
185+
loaded_opt_weights = tf.keras.backend.batch_get_value(
186+
loaded_model.optimizer.weights
187+
)
188+
else:
189+
loaded_opt_weights = [
190+
var.numpy() for var in loaded_model.optimizer.variables()
191+
]
171192

172193
# Assert optimizer weights are equal
173194
for weight_model, weight_loaded_model in zip(
@@ -209,8 +230,20 @@ def test_save_load_with_dense_features(self, tmpdir, loss, optimizer, metrics):
209230
tiledb_model_obj.save(include_optimizer=True)
210231
loaded_model = tiledb_model_obj.load(compile_model=True)
211232

212-
model_opt_weights = batch_get_value(model.optimizer.weights)
213-
loaded_opt_weights = batch_get_value(loaded_model.optimizer.weights)
233+
if hasattr(model.optimizer, "weights"):
234+
model_opt_weights = tf.keras.backend.batch_get_value(
235+
model.optimizer.weights
236+
)
237+
else:
238+
model_opt_weights = [var.numpy() for var in model.optimizer.variables()]
239+
if hasattr(loaded_model.optimizer, "weights"):
240+
loaded_opt_weights = tf.keras.backend.batch_get_value(
241+
loaded_model.optimizer.weights
242+
)
243+
else:
244+
loaded_opt_weights = [
245+
var.numpy() for var in loaded_model.optimizer.variables()
246+
]
214247

215248
# Assert optimizer weights are equal
216249
for weight_model, weight_loaded_model in zip(
@@ -260,8 +293,20 @@ def test_save_load_with_sequence_features(self, tmpdir, loss, optimizer, metrics
260293
tiledb_model_obj.save(include_optimizer=True)
261294
loaded_model = tiledb_model_obj.load(compile_model=True)
262295

263-
model_opt_weights = batch_get_value(model.optimizer.weights)
264-
loaded_opt_weights = batch_get_value(loaded_model.optimizer.weights)
296+
if hasattr(model.optimizer, "weights"):
297+
model_opt_weights = tf.keras.backend.batch_get_value(
298+
model.optimizer.weights
299+
)
300+
else:
301+
model_opt_weights = [var.numpy() for var in model.optimizer.variables()]
302+
if hasattr(loaded_model.optimizer, "weights"):
303+
loaded_opt_weights = tf.keras.backend.batch_get_value(
304+
loaded_model.optimizer.weights
305+
)
306+
else:
307+
loaded_opt_weights = [
308+
var.numpy() for var in loaded_model.optimizer.variables()
309+
]
265310

266311
# Assert optimizer weights are equal
267312
for weight_model, weight_loaded_model in zip(
@@ -277,7 +322,7 @@ def test_save_load_with_sequence_features(self, tmpdir, loss, optimizer, metrics
277322
indices_a[:, 0] = np.arange(10)
278323
inputs_a = tf.SparseTensor(indices_a, values_a, (batch_size, timesteps, 1))
279324

280-
values_b = np.zeros(10, dtype=np.str)
325+
values_b = np.zeros(10, dtype=str)
281326
indices_b = np.zeros((10, 3), dtype=np.int64)
282327
indices_b[:, 0] = np.arange(10)
283328
inputs_b = tf.SparseTensor(indices_b, values_b, (batch_size, timesteps, 1))
@@ -310,7 +355,7 @@ def test_functional_model_save_load_with_custom_loss_and_metric(self, tmpdir):
310355
tiledb_uri = os.path.join(tmpdir, "model_array")
311356
tiledb_model_obj = TensorflowKerasTileDBModel(uri=tiledb_uri, model=model)
312357
tiledb_model_obj.save(include_optimizer=True)
313-
loaded_model = tiledb_model_obj.load(compile_model=True)
358+
loaded_model = tiledb_model_obj.load(compile_model=True, safe_mode=False)
314359

315360
# Assert all evaluation results are the same.
316361
assert all(

tiledb/ml/models/tensorflow_keras.py

+54-10
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,39 @@
1515

1616
from ._base import Meta, TileDBArtifact, Timestamp
1717

18-
FunctionalOrSequential = (keras.models.Functional, keras.models.Sequential)
19-
TFOptimizer = keras.optimizers.TFOptimizer
20-
get_json_type = keras.saving.saved_model.json_utils.get_json_type
21-
preprocess_weights_for_loading = keras.saving.hdf5_format.preprocess_weights_for_loading
22-
saving_utils = keras.saving.saving_utils
18+
keras_major, keras_minor, keras_patch = keras.__version__.split(".")
19+
FunctionalOrSequential = keras.models.Sequential
20+
# Handle keras <=v2.10
21+
if int(keras_major) <= 2 and int(keras_minor) <= 10:
22+
FunctionalOrSequential = (keras.models.Functional, keras.models.Sequential)
23+
TFOptimizer = keras.optimizers.TFOptimizer
24+
get_json_type = keras.saving.saved_model.json_utils.get_json_type
25+
preprocess_weights_for_loading = (
26+
keras.saving.hdf5_format.preprocess_weights_for_loading
27+
)
28+
saving_utils = keras.saving.saving_utils
29+
# Handle keras >=v2.11
30+
elif int(keras_major) <= 2 and int(keras_minor) <= 12:
31+
FunctionalOrSequential = (keras.models.Functional, keras.models.Sequential)
32+
TFOptimizer = tf.keras.optimizers.legacy.Optimizer
33+
get_json_type = keras.saving.legacy.saved_model.json_utils.get_json_type
34+
preprocess_weights_for_loading = (
35+
keras.saving.legacy.hdf5_format.preprocess_weights_for_loading
36+
)
37+
saving_utils = keras.saving.legacy.saving_utils
38+
else:
39+
from keras.src.saving.serialization_lib import SafeModeScope
40+
41+
FunctionalOrSequential = (
42+
keras.src.engine.functional.Functional,
43+
keras.src.engine.sequential.Sequential,
44+
)
45+
TFOptimizer = tf.keras.optimizers.legacy.Optimizer
46+
get_json_type = keras.src.saving.legacy.saved_model.json_utils.get_json_type
47+
preprocess_weights_for_loading = (
48+
keras.src.saving.legacy.hdf5_format.preprocess_weights_for_loading
49+
)
50+
saving_utils = keras.src.saving.legacy.saving_utils
2351

2452

2553
class TensorflowKerasTileDBModel(TileDBArtifact[tf.keras.Model]):
@@ -59,7 +87,7 @@ def save(
5987

6088
if not isinstance(self.artifact, FunctionalOrSequential):
6189
raise RuntimeError(
62-
"Subclassed Models (Custom Layers) not supported at the moment."
90+
f"Subclassed Models (Custom Layers) for {type(self.artifact)} not supported at the moment."
6391
)
6492

6593
# Used in this format only when model is Functional or Sequential
@@ -109,6 +137,7 @@ def load(
109137
custom_objects: Optional[Mapping[str, Any]] = None,
110138
input_shape: Optional[Tuple[int, ...]] = None,
111139
callback: bool = False,
140+
safe_mode: Optional[bool] = None,
112141
) -> tf.keras.Model:
113142
"""
114143
Load switch, i.e, decide between __load (TileDB-ML<=0.8.0) or __load_v2 (TileDB-ML>0.8.0).
@@ -129,7 +158,7 @@ def load(
129158
model_array, compile_model, callback, custom_objects
130159
)
131160
else:
132-
return self.__load(model_array, compile_model, callback)
161+
return self.__load(model_array, compile_model, callback, safe_mode)
133162

134163
def __load_legacy(
135164
self,
@@ -200,13 +229,25 @@ def __load_legacy(
200229
return model
201230

202231
def __load(
203-
self, model_array: tiledb.Array, compile_model: bool, callback: bool
232+
self,
233+
model_array: tiledb.Array,
234+
compile_model: bool,
235+
callback: bool,
236+
safe_mode: Optional[bool],
204237
) -> tf.keras.Model:
205238
model_config = json.loads(model_array.meta["model_config"])
206239
model_class = model_config["class_name"]
207240

208241
cls = tf.keras.Sequential if model_class == "Sequential" else tf.keras.Model
209-
model = cls.from_config(model_config["config"])
242+
243+
if int(keras_major) <= 2 and int(keras_minor) >= 13:
244+
if safe_mode is not None:
245+
with SafeModeScope(safe_mode=safe_mode):
246+
model = cls.from_config(model_config["config"])
247+
else:
248+
model = cls.from_config(model_config["config"])
249+
else:
250+
model = cls.from_config(model_config["config"])
210251
model_weights = self._get_model_param(model_array, "model")
211252
model.set_weights(model_weights)
212253

@@ -266,6 +307,9 @@ def _serialize_optimizer_weights(
266307
assert self.artifact
267308
optimizer = self.artifact.optimizer
268309
if optimizer and not isinstance(optimizer, TFOptimizer):
269-
optimizer_weights = tf.keras.backend.batch_get_value(optimizer.weights)
310+
if hasattr(optimizer, "weights"):
311+
optimizer_weights = tf.keras.backend.batch_get_value(optimizer.weights)
312+
else:
313+
optimizer_weights = [var.numpy() for var in optimizer.variables()]
270314
return pickle.dumps(optimizer_weights, protocol=4)
271315
return b""

0 commit comments

Comments
 (0)