Skip to content

Commit 92f594d

Browse files
authored
Moderate model refactoring (#199)
* Inline _get_file_properties * TileDBArtifact: Merge _write_model_metadata into _write_array * TensorflowKerasTileDBModel: pass model_metadata to the _write_array() call * Pass tensorboard_log_dir to _write_array * Refactor _load_tensorboard to use _get_model_param * _get_model_param: Fetch only the key attribute * Open/close the tiledb array only once per load call
1 parent bf5644e commit 92f594d

File tree

4 files changed

+194
-313
lines changed

4 files changed

+194
-313
lines changed

tiledb/ml/models/_base.py

Lines changed: 59 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,14 @@ def __init__(
6565
self.ctx = ctx
6666
self.artifact = artifact
6767
self.uri = get_cloud_uri(uri, namespace) if namespace else uri
68-
self._file_properties = self._get_file_properties()
68+
self._file_properties = {
69+
ModelFileProperties.TILEDB_ML_MODEL_ML_FRAMEWORK.value: self.Name,
70+
ModelFileProperties.TILEDB_ML_MODEL_ML_FRAMEWORK_VERSION.value: self.Version,
71+
ModelFileProperties.TILEDB_ML_MODEL_STAGE.value: "STAGING",
72+
ModelFileProperties.TILEDB_ML_MODEL_PYTHON_VERSION.value: platform.python_version(),
73+
ModelFileProperties.TILEDB_ML_MODEL_PREVIEW.value: self.preview(),
74+
ModelFileProperties.TILEDB_ML_MODEL_VERSION.value: __version__,
75+
}
6976

7077
@abstractmethod
7178
def save(self, *, update: bool = False, meta: Optional[Meta] = None) -> None:
@@ -88,34 +95,23 @@ def get_weights(self, timestamp: Optional[Timestamp] = None) -> Weights:
8895
"""
8996
Returns model's weights. Works for Tensorflow Keras and PyTorch
9097
"""
91-
return cast(Weights, self._get_model_param("model", timestamp))
98+
with tiledb.open(self.uri, ctx=self.ctx, timestamp=timestamp) as model_array:
99+
return cast(Weights, self._get_model_param(model_array, "model"))
92100

93101
def get_optimizer_weights(self, timestamp: Optional[Timestamp] = None) -> Weights:
94102
"""
95103
Returns optimizer's weights. Works for Tensorflow Keras and PyTorch
96104
"""
97-
return cast(Weights, self._get_model_param("optimizer", timestamp))
105+
with tiledb.open(self.uri, ctx=self.ctx, timestamp=timestamp) as model_array:
106+
return cast(Weights, self._get_model_param(model_array, "optimizer"))
98107

99108
@abstractmethod
100109
def preview(self) -> str:
101110
"""
102111
Creates a string representation of a machine learning model.
103112
"""
104113

105-
def _get_file_properties(self) -> Mapping[str, str]:
106-
return {
107-
ModelFileProperties.TILEDB_ML_MODEL_ML_FRAMEWORK.value: self.Name,
108-
ModelFileProperties.TILEDB_ML_MODEL_ML_FRAMEWORK_VERSION.value: self.Version,
109-
ModelFileProperties.TILEDB_ML_MODEL_STAGE.value: "STAGING",
110-
ModelFileProperties.TILEDB_ML_MODEL_PYTHON_VERSION.value: platform.python_version(),
111-
ModelFileProperties.TILEDB_ML_MODEL_PREVIEW.value: self.preview(),
112-
ModelFileProperties.TILEDB_ML_MODEL_VERSION.value: __version__,
113-
}
114-
115-
def _create_array(
116-
self,
117-
fields: Sequence[str],
118-
) -> None:
114+
def _create_array(self, fields: Sequence[str]) -> None:
119115
"""Internal method that creates a TileDB array based on the model's spec."""
120116

121117
# The array will be be 1 dimensional with domain of 0 to max uint64. We use a tile extent of 1024 bytes
@@ -152,101 +148,78 @@ def _create_array(
152148
if self.namespace:
153149
update_file_properties(self.uri, self._file_properties)
154150

155-
def _write_array(self, model_params: Mapping[str, bytes]) -> None:
156-
"""
157-
Writes machine learning model related data, i.e., model weights, optimizer weights and Tensorboard files, to
158-
a dense TileDB array.
159-
"""
151+
def _write_array(
152+
self,
153+
model_params: Mapping[str, bytes],
154+
tensorboard_log_dir: Optional[str] = None,
155+
meta: Optional[Meta] = None,
156+
) -> None:
157+
if tensorboard_log_dir:
158+
tensorboard = self._serialize_tensorboard(tensorboard_log_dir)
159+
else:
160+
tensorboard = b""
161+
model_params = dict(tensorboard=tensorboard, **model_params)
162+
163+
if meta is None:
164+
meta = {}
165+
if not meta.keys().isdisjoint(self._file_properties.keys()):
166+
raise ValueError(
167+
"Please avoid using file property key names as metadata keys!"
168+
)
160169

161170
with tiledb.open(self.uri, "w", ctx=self.ctx) as model_array:
162171
one_d_buffers = {}
163172
max_len = 0
164-
165173
for key, value in model_params.items():
166174
one_d_buffer = np.frombuffer(value, dtype=np.uint8)
167175
one_d_buffer_len = len(one_d_buffer)
168176
one_d_buffers[key] = one_d_buffer
169-
170177
# Write size only in case is greater than 0.
171178
if one_d_buffer_len:
172179
model_array.meta[key + "_size"] = one_d_buffer_len
173-
174180
if one_d_buffer_len > max_len:
175181
max_len = one_d_buffer_len
176182

177183
model_array[0:max_len] = {
178184
key: np.pad(value, (0, max_len - len(value)))
179185
for key, value in one_d_buffers.items()
180186
}
181-
182-
def _write_model_metadata(self, meta: Meta) -> None:
183-
"""
184-
Update the metadata in a TileDB model array. File properties also go in the metadata section.
185-
:param meta: A mapping with the <key, value> pairs to be inserted in array's metadata.
186-
"""
187-
with tiledb.open(self.uri, "w", ctx=self.ctx) as model_array:
188-
# Raise ValueError in case users provide metadata with the same keys as file properties.
189-
if not meta.keys().isdisjoint(self._file_properties.keys()):
190-
raise ValueError(
191-
"Please avoid using file property key names as metadata keys!"
192-
)
193-
194-
for key, value in meta.items():
195-
model_array.meta[key] = value
196-
197-
for key, value in self._file_properties.items():
198-
model_array.meta[key] = value
187+
for mapping in meta, self._file_properties:
188+
for key, value in mapping.items():
189+
model_array.meta[key] = value
190+
191+
def _get_model_param(self, model_array: tiledb.Array, key: str) -> Any:
192+
size_key = key + "_size"
193+
try:
194+
size = model_array.meta[size_key]
195+
except KeyError:
196+
raise Exception(
197+
f"{size_key} metadata entry not present in {self.uri}"
198+
f" (existing keys: {set(model_array.meta.keys())})"
199+
)
200+
return pickle.loads(model_array.query(attrs=(key,))[0:size][key].tobytes())
199201

200202
@staticmethod
201-
def _serialize_tensorboard_files(log_dir: str) -> bytes:
203+
def _serialize_tensorboard(log_dir: str) -> bytes:
202204
"""Serialize all Tensorboard files."""
203-
204205
if not os.path.exists(log_dir):
205206
raise ValueError(f"{log_dir} does not exist")
206-
207-
event_files = {}
207+
tensorboard_files = {}
208208
for path in glob.glob(f"{log_dir}/*tfevents*"):
209209
with open(path, "rb") as f:
210-
event_files[path] = f.read()
210+
tensorboard_files[path] = f.read()
211+
return pickle.dumps(tensorboard_files, protocol=4)
211212

212-
return pickle.dumps(event_files, protocol=4)
213-
214-
def _get_model_param(self, key: str, timestamp: Optional[Timestamp]) -> Any:
215-
with tiledb.open(self.uri, ctx=self.ctx, timestamp=timestamp) as model_array:
216-
size_key = key + "_size"
217-
try:
218-
size = model_array.meta[size_key]
219-
except KeyError:
220-
raise Exception(
221-
f"{size_key} metadata entry not present in {self.uri}"
222-
f" (existing keys: {set(model_array.meta.keys())})"
223-
)
224-
return pickle.loads(model_array[0:size][key].tobytes())
225-
226-
def _load_tensorboard(self, timestamp: Optional[Timestamp] = None) -> None:
213+
def _load_tensorboard(self, model_array: tiledb.Array) -> None:
227214
"""
228-
Writes Tensorboard files to directory. Works for Tensorflow-Keras and PyTorch.
215+
Write Tensorboard files to directory. Works for Tensorflow-Keras and PyTorch.
229216
"""
230-
with tiledb.open(self.uri, ctx=self.ctx, timestamp=timestamp) as model_array:
231-
try:
232-
tensorboard_size = model_array.meta["tensorboard_size"]
233-
except KeyError:
234-
raise Exception(
235-
f"tensorboard_size metadata entry not present in"
236-
f" (existing keys: {set(model_array.meta.keys())})"
237-
)
238-
239-
tb_contents = model_array[0:tensorboard_size]["tensorboard"]
240-
tensorboard_files = pickle.loads(tb_contents.tobytes())
241-
242-
for path, file_bytes in tensorboard_files.items():
243-
log_dir = os.path.dirname(path)
244-
if not os.path.exists(log_dir):
245-
os.mkdir(log_dir)
246-
with open(os.path.join(log_dir, os.path.basename(path)), "wb") as f:
247-
f.write(file_bytes)
248-
249-
def _use_legacy_schema(self, timestamp: Optional[Timestamp]) -> bool:
217+
tensorboard_files = self._get_model_param(model_array, "tensorboard")
218+
for path, file_bytes in tensorboard_files.items():
219+
os.makedirs(os.path.dirname(path), exist_ok=True)
220+
with open(path, "wb") as f:
221+
f.write(file_bytes)
222+
223+
def _use_legacy_schema(self, model_array: tiledb.Array) -> bool:
250224
# TODO: Decide based on tiledb-ml version and not on schema characteristics, like "offset".
251-
with tiledb.open(self.uri, ctx=self.ctx, timestamp=timestamp) as model_array:
252-
return str(model_array.schema.domain.dim(0).name) != "offset"
225+
return str(model_array.schema.domain.dim(0).name) != "offset"

tiledb/ml/models/pytorch.py

Lines changed: 13 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -62,14 +62,6 @@ def save(
6262
else:
6363
serialized_optimizer_dict = b""
6464

65-
# Serialize Tensorboard files
66-
if summary_writer:
67-
tensorboard = self._serialize_tensorboard_files(
68-
log_dir=summary_writer.log_dir
69-
)
70-
else:
71-
tensorboard = b""
72-
7365
# Create TileDB model array
7466
if not update:
7567
self._create_array(fields=["model", "optimizer", "tensorboard"])
@@ -78,13 +70,11 @@ def save(
7870
model_params={
7971
"model": serialized_model_dict,
8072
"optimizer": serialized_optimizer_dict,
81-
"tensorboard": tensorboard,
82-
}
73+
},
74+
tensorboard_log_dir=summary_writer.log_dir if summary_writer else None,
75+
meta=meta,
8376
)
8477

85-
if meta:
86-
self._write_model_metadata(meta=meta)
87-
8878
def load(
8979
self,
9080
*,
@@ -102,29 +92,19 @@ def load(
10292
:param callback: Boolean variable if True will store Callback data into saved directory
10393
:return: A dictionary with attributes other than model or optimizer state_dict.
10494
"""
105-
106-
load = (
107-
self.__load_legacy
108-
if self._use_legacy_schema(timestamp=timestamp)
109-
else self.__load
110-
)
111-
return load(
112-
model=model, optimizer=optimizer, timestamp=timestamp, callback=callback
113-
)
95+
with tiledb.open(self.uri, ctx=self.ctx, timestamp=timestamp) as model_array:
96+
if self._use_legacy_schema(model_array):
97+
return self.__load_legacy(model_array, model, optimizer, callback)
98+
else:
99+
return self.__load(model_array, model, optimizer, callback)
114100

115101
def __load_legacy(
116102
self,
103+
model_array: tiledb.Array,
117104
model: torch.nn.Module,
118105
optimizer: Optimizer,
119-
timestamp: Optional[Timestamp],
120106
callback: bool,
121107
) -> Optional[Mapping[str, Any]]:
122-
"""
123-
Load a PyTorch model from a TileDB array.
124-
"""
125-
126-
# TODO: Change timestamp when issue in core is resolved
127-
model_array = tiledb.open(self.uri, ctx=self.ctx, timestamp=timestamp)
128108
model_array_results = model_array[:]
129109
schema = model_array.schema
130110

@@ -169,25 +149,16 @@ def __load_legacy(
169149

170150
def __load(
171151
self,
152+
model_array: tiledb.Array,
172153
model: torch.nn.Module,
173154
optimizer: Optimizer,
174-
timestamp: Optional[Timestamp],
175155
callback: bool,
176156
) -> None:
177-
"""
178-
Load a PyTorch model from a TileDB array.
179-
"""
180-
181-
model_state_dict = self.get_weights(timestamp=timestamp)
182-
model.load_state_dict(model_state_dict)
183-
184-
# Load model's state dictionary
157+
model.load_state_dict(self._get_model_param(model_array, "model"))
185158
if optimizer:
186-
opt_state_dict = self.get_optimizer_weights(timestamp=timestamp)
187-
optimizer.load_state_dict(opt_state_dict)
188-
159+
optimizer.load_state_dict(self._get_model_param(model_array, "optimizer"))
189160
if callback:
190-
self._load_tensorboard(timestamp=timestamp)
161+
self._load_tensorboard(model_array)
191162

192163
def preview(self) -> str:
193164
"""

tiledb/ml/models/sklearn.py

Lines changed: 10 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,7 @@ def save(self, *, update: bool = False, meta: Optional[Meta] = None) -> None:
4848
if not update:
4949
self._create_array(fields=["model"])
5050

51-
self._write_array(model_params={"model": serialized_model})
52-
53-
if meta:
54-
self._write_model_metadata(meta=meta)
51+
self._write_array(model_params={"model": serialized_model}, meta=meta)
5552

5653
def load(self, *, timestamp: Optional[Timestamp] = None) -> BaseEstimator:
5754
"""
@@ -62,42 +59,17 @@ def load(self, *, timestamp: Optional[Timestamp] = None) -> BaseEstimator:
6259
in the specified time range.
6360
:return: A Sklearn model object.
6461
"""
65-
# TODO: Change timestamp when issue in core is resolved
66-
67-
load = (
68-
self.__load_legacy
69-
if self._use_legacy_schema(timestamp=timestamp)
70-
else self.__load
71-
)
72-
return load(timestamp=timestamp)
73-
74-
def __load_legacy(self, *, timestamp: Optional[Timestamp]) -> BaseEstimator:
75-
"""
76-
Load a Sklearn model from a TileDB array.
77-
"""
78-
model_array = tiledb.open(self.uri, ctx=self.ctx, timestamp=timestamp)
79-
model_array_results = model_array[:]
80-
model = pickle.loads(model_array_results["model_params"].item(0))
81-
return model
62+
with tiledb.open(self.uri, ctx=self.ctx, timestamp=timestamp) as model_array:
63+
if self._use_legacy_schema(model_array):
64+
return self.__load_legacy(model_array)
65+
else:
66+
return self.__load(model_array)
8267

83-
def __load(self, *, timestamp: Optional[Timestamp]) -> BaseEstimator:
84-
"""
85-
Load a Sklearn model from a TileDB array.
86-
"""
68+
def __load_legacy(self, model_array: tiledb.Array) -> BaseEstimator:
69+
return pickle.loads(model_array[:]["model_params"].item(0))
8770

88-
with tiledb.open(self.uri, ctx=self.ctx, timestamp=timestamp) as model_array:
89-
try:
90-
model_size = model_array.meta["model_size"]
91-
except KeyError:
92-
raise Exception(
93-
f"model_size metadata entry not present in {self.uri}"
94-
f" (existing keys: {set(model_array.meta.keys())})"
95-
)
96-
97-
model_contents = model_array[0:model_size]["model"]
98-
model_bytes = model_contents.tobytes()
99-
100-
return pickle.loads(model_bytes)
71+
def __load(self, model_array: tiledb.Array) -> BaseEstimator:
72+
return self._get_model_param(model_array, "model")
10173

10274
def preview(self, *, display: str = "text") -> str:
10375
"""

0 commit comments

Comments
 (0)