Skip to content

Commit 87262c6

Browse files
Increase unit test coverage (#201)
* cloud utils unit tests
* pytorch model unit tests
* tensorflow keras model unit tests
* _use_legacy_schema became a static method
* sklearn model unit tests
* cloud utils unit tests
* pytorch model unit tests
* tensorflow keras model unit tests
* _use_legacy_schema became a static method
* sklearn model unit tests
* removed tensorflow custom layer commented code
* PR changes
1 parent 23db4e8 commit 87262c6

7 files changed, with 40 additions and 91 deletions.

tests/models/test_cloud_utils.py

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,5 @@
1+
import pytest
2+
13
from tiledb.ml.models._cloud_utils import (
24
get_cloud_uri,
35
get_s3_prefix,
@@ -8,6 +10,8 @@
810
class TestCloudUtils:
911
def test_get_s3_prefix(self, mocker):
1012

13+
assert get_s3_prefix(None) is None
14+
1115
profile = mocker.patch(
1216
"tiledb.cloud.client.user_profile",
1317
return_value=mocker.Mock(username="foo", default_s3_path="bar"),
@@ -47,6 +51,12 @@ def test_get_cloud_uri(self, mocker):
4751
uri="tiledb_array", namespace="test_namespace"
4852
)
4953

54+
mocker.patch("tiledb.ml.models._cloud_utils.get_s3_prefix", return_value=None)
55+
with pytest.raises(ValueError) as ex:
56+
get_cloud_uri(uri="tiledb_array", namespace="test_namespace")
57+
58+
assert "You must set the default s3 prefix path for ML models" in str(ex.value)
59+
5060
def test_update_file_properties(self, mocker):
5161
mock_tiledb_cloud_update_file_properties = mocker.patch(
5262
"tiledb.cloud.array.update_file_properties"

tests/models/test_pytorch_models.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -127,6 +127,12 @@ def test_save(self, tmpdir, net, optimizer):
127127
):
128128
assert all([a == b for a, b in zip(key_item_1[1], key_item_2[1])])
129129

130+
with pytest.raises(RuntimeError) as ex:
131+
tiledb_obj = PyTorchTileDBModel(uri="")
132+
tiledb_obj.save()
133+
134+
assert "Model is not initialized" in str(ex.value)
135+
130136
@net
131137
def test_preview(self, tmpdir, net):
132138
# With model given as argument

tests/models/test_sklearn_models.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -36,6 +36,12 @@ def test_save_load(self, tmpdir, net):
3636
]
3737
)
3838

39+
with pytest.raises(RuntimeError) as ex:
40+
tiledb_obj = SklearnTileDBModel(uri="")
41+
tiledb_obj.save()
42+
43+
assert "Model is not initialized" in str(ex.value)
44+
3945
def test_preview(self, tmpdir, net):
4046
# With model as argument
4147
tiledb_array = os.path.join(tmpdir, "test_array")

tests/models/test_tensorflow_keras_models.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -111,6 +111,12 @@ def test_save_model_to_tiledb_array(self, tmpdir, api, loss, optimizer, metrics)
111111
tiledb_model_obj.save(include_optimizer=True if optimizer else False)
112112
assert tiledb.array_exists(tiledb_uri)
113113

114+
with pytest.raises(RuntimeError) as ex:
115+
tiledb_obj = TensorflowKerasTileDBModel(uri="")
116+
tiledb_obj.save()
117+
118+
assert "Model is not initialized" in str(ex.value)
119+
114120
@api
115121
@loss_optimizer_metrics
116122
def test_save_model_to_tiledb_array_predictions(

tiledb/ml/models/_base.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -217,6 +217,7 @@ def _load_tensorboard(self, model_array: tiledb.Array) -> None:
217217
with open(path, "wb") as f:
218218
f.write(file_bytes)
219219

220-
def _use_legacy_schema(self, model_array: tiledb.Array) -> bool:
220+
@staticmethod
221+
def _use_legacy_schema(model_array: tiledb.Array) -> bool:
221222
# TODO: Decide based on tiledb-ml version and not on schema characteristics, like "offset".
222223
return str(model_array.schema.domain.dim(0).name) != "offset"

tiledb/ml/models/sklearn.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -63,7 +63,8 @@ def load(self, *, timestamp: Optional[Timestamp] = None) -> BaseEstimator:
6363
else:
6464
return self.__load(model_array)
6565

66-
def __load_legacy(self, model_array: tiledb.Array) -> BaseEstimator:
66+
@staticmethod
67+
def __load_legacy(model_array: tiledb.Array) -> BaseEstimator:
6768
return pickle.loads(model_array[:]["model_params"].item(0))
6869

6970
def __load(self, model_array: tiledb.Array) -> BaseEstimator:

tiledb/ml/models/tensorflow_keras.py

Lines changed: 8 additions & 89 deletions
Original file line number | Diff line number | Diff line change
@@ -6,10 +6,9 @@
66
import os
77
import pickle
88
from collections import ChainMap
9-
from typing import Any, List, Mapping, Optional, Tuple
9+
from typing import Any, Mapping, Optional, Tuple
1010

1111
import keras
12-
import numpy as np
1312
import tensorflow as tf
1413

1514
import tiledb
@@ -128,7 +127,7 @@ def load(
128127
with tiledb.open(self.uri, ctx=self.ctx, timestamp=timestamp) as model_array:
129128
if self._use_legacy_schema(model_array):
130129
return self.__load_legacy(
131-
model_array, compile_model, callback, custom_objects, input_shape
130+
model_array, compile_model, callback, custom_objects
132131
)
133132
else:
134133
return self.__load(model_array, compile_model, callback)
@@ -139,30 +138,15 @@ def __load_legacy(
139138
compile_model: bool,
140139
callback: bool,
141140
custom_objects: Optional[Mapping[str, Any]],
142-
input_shape: Optional[Tuple[int, ...]],
143141
) -> tf.keras.Model:
144142
model_array_results = model_array[:]
145143
model_config = json.loads(model_array.meta["model_config"])
146144
model_class = model_config["class_name"]
147145

148-
if model_class not in ("Functional", "Sequential"):
149-
with SharedObjectLoadingScope():
150-
with tf.keras.utils.CustomObjectScope(custom_objects or {}):
151-
if hasattr(model_config, "decode"):
152-
model_config = model_config.decode("utf-8")
153-
model = tf.keras.models.model_from_config(
154-
model_config, custom_objects=custom_objects
155-
)
156-
if not model.built:
157-
model.build(input_shape)
158-
159-
# Load weights for layers
160-
self._load_custom_subclassed_model(model, model_array)
161-
else:
162-
cls = tf.keras.Sequential if model_class == "Sequential" else tf.keras.Model
163-
model = cls.from_config(model_config["config"])
164-
model_weights = pickle.loads(model_array_results["model_weights"].item(0))
165-
model.set_weights(model_weights)
146+
cls = tf.keras.Sequential if model_class == "Sequential" else tf.keras.Model
147+
model = cls.from_config(model_config["config"])
148+
model_weights = pickle.loads(model_array_results["model_weights"].item(0))
149+
model.set_weights(model_weights)
166150

167151
if compile_model:
168152
optimizer_weights = pickle.loads(
@@ -198,6 +182,7 @@ def __load_legacy(
198182
"starting with a freshly initialized "
199183
"optimizer."
200184
)
185+
201186
if callback:
202187
try:
203188
with tiledb.open(f"{self.uri}-tensorboard") as tb_array:
@@ -239,6 +224,7 @@ def __load(
239224
saving_utils.try_build_compiled_arguments(model)
240225

241226
optimizer_weights = self._get_model_param(model_array, "optimizer")
227+
242228
# Set optimizer weights.
243229
if optimizer_weights:
244230
try:
@@ -284,70 +270,3 @@ def _serialize_optimizer_weights(
284270
optimizer_weights = tf.keras.backend.batch_get_value(optimizer.weights)
285271
return pickle.dumps(optimizer_weights, protocol=4)
286272
return b""
287-
288-
def _load_custom_subclassed_model(
289-
self, model: tf.keras.Model, model_array: tiledb.Array
290-
) -> None:
291-
if "keras_version" in model_array.meta:
292-
original_keras_version = model_array.meta["keras_version"]
293-
if hasattr(original_keras_version, "decode"):
294-
original_keras_version = original_keras_version.decode("utf8")
295-
else:
296-
original_keras_version = "1"
297-
if "backend" in model_array.meta:
298-
original_backend = model_array.meta["backend"]
299-
if hasattr(original_backend, "decode"):
300-
original_backend = original_backend.decode("utf8")
301-
else:
302-
original_backend = None
303-
304-
# Load weights for layers
305-
self._load_weights_from_tiledb(
306-
model_array[:], model, original_keras_version, original_backend
307-
)
308-
309-
@staticmethod
310-
def _load_weights_from_tiledb(
311-
model_array_results: Mapping[str, Any],
312-
model: tf.keras.Model,
313-
original_keras_version: Optional[str],
314-
original_backend: Optional[str],
315-
) -> None:
316-
num_layers = 0
317-
for layer in model.layers:
318-
weights = layer.trainable_weights + layer.non_trainable_weights
319-
if weights:
320-
num_layers += 1
321-
322-
read_layer_names = []
323-
for k, name in enumerate(model_array_results["layer_name"]):
324-
layer_weight_names = pickle.loads(
325-
model_array_results["weight_names"].item(k)
326-
)
327-
if layer_weight_names:
328-
read_layer_names.append(name)
329-
330-
if len(read_layer_names) != num_layers:
331-
raise ValueError(
332-
f"You are trying to load a weight file with {len(read_layer_names)} "
333-
f"layers into a model with {num_layers} layers"
334-
)
335-
336-
var_value_tuples: List[Tuple[tf.Variable, np.ndarray]] = []
337-
for k, layer in enumerate(model.layers):
338-
weight_vars = layer.trainable_weights + layer.non_trainable_weights
339-
read_weight_values = pickle.loads(
340-
model_array_results["weight_values"].item(k)
341-
)
342-
read_weight_values = preprocess_weights_for_loading(
343-
layer, read_weight_values, original_keras_version, original_backend
344-
)
345-
if len(read_weight_values) != len(weight_vars):
346-
raise ValueError(
347-
f'Layer #{k} (named "{layer.name}" in the current model) was found '
348-
f"to correspond to layer {layer} in the save file. However the new "
349-
f"layer {layer.name} expects {len(weight_vars)} weights, "
350-
f"but the saved weights have {len(read_weight_values)} elements"
351-
)
352-
var_value_tuples.extend(zip(weight_vars, read_weight_values))
353-
tf.keras.backend.batch_set_value(var_value_tuples)

0 commit comments

Comments
 (0)