Skip to content

Commit b44a50c

Browse files
authored
[Fix/Enhancement] Adding default argument current time timestamp in open 'w' mode (#89)
* Adding in every save operation current time timestamp
1 parent fc781c9 commit b44a50c

File tree

4 files changed

+28
-6
lines changed

4 files changed

+28
-6
lines changed

tiledb/ml/models/base.py

+5
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import os
44
import platform
5+
import time
56
from abc import ABC, abstractmethod
67
from enum import Enum, unique
78
from typing import Any, Generic, Mapping, Optional, Tuple, TypeVar
@@ -26,6 +27,10 @@ class ModelFileProperties(Enum):
2627
TILEDB_ML_MODEL_PREVIEW = "TILEDB_ML_MODEL_PREVIEW"
2728

2829

30+
def current_milli_time() -> int:
31+
return round(time.time() * 1000)
32+
33+
2934
class TileDBModel(ABC, Generic[Model]):
3035
"""
3136
This is the base class for all TileDB model storage functionalities, i.e,

tiledb/ml/models/pytorch.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import tiledb
1111

12-
from .base import Meta, TileDBModel, Timestamp
12+
from .base import Meta, TileDBModel, Timestamp, current_milli_time
1313

1414

1515
class PyTorchTileDBModel(TileDBModel[torch.nn.Module]):
@@ -102,6 +102,8 @@ def load( # type: ignore
102102
:param optimizer: A defined PyTorch optimizer.
103103
:return: A dictionary with attributes other than model or optimizer state_dict.
104104
"""
105+
106+
# TODO: Change timestamp when issue in core is resolved
105107
model_array = tiledb.open(self.uri, ctx=self.ctx, timestamp=timestamp)
106108
model_array_results = model_array[:]
107109
schema = model_array.schema
@@ -211,7 +213,11 @@ def _write_array(
211213
optimizer state, extra model information) of a PyTorch model.
212214
:param meta: Extra metadata to save in a TileDB array.
213215
"""
214-
with tiledb.open(self.uri, "w", ctx=self.ctx) as tf_model_tiledb:
216+
217+
# TODO: Change timestamp when issue in core is resolved
218+
with tiledb.open(
219+
self.uri, "w", timestamp=current_milli_time(), ctx=self.ctx
220+
) as tf_model_tiledb:
215221
# Insertion in TileDB array
216222
tf_model_tiledb[:] = {
217223
key: np.array([value]) for key, value in serialized_model_dict.items()

tiledb/ml/models/sklearn.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
import tiledb
1212

13-
from .base import Meta, TileDBModel, Timestamp
13+
from .base import Meta, TileDBModel, Timestamp, current_milli_time
1414

1515

1616
class SklearnTileDBModel(TileDBModel[BaseEstimator]):
@@ -47,6 +47,8 @@ def load(self, *, timestamp: Optional[Timestamp] = None) -> BaseEstimator:
4747
in the specified time range.
4848
:return: A Sklearn model object.
4949
"""
50+
# TODO: Change timestamp when issue in core is resolved
51+
5052
model_array = tiledb.open(self.uri, ctx=self.ctx, timestamp=timestamp)
5153
model_array_results = model_array[:]
5254
model = pickle.loads(model_array_results["model_params"].item(0))
@@ -102,7 +104,11 @@ def _write_array(self, serialized_model: bytes, meta: Optional[Meta]) -> None:
102104
:param serialized_model: A pickled sklearn model.
103105
:param meta: Extra metadata to save in a TileDB array.
104106
"""
105-
with tiledb.open(self.uri, "w", ctx=self.ctx) as tf_model_tiledb:
107+
# TODO: Change timestamp when issue in core is resolved
108+
109+
with tiledb.open(
110+
self.uri, "w", timestamp=current_milli_time(), ctx=self.ctx
111+
) as tf_model_tiledb:
106112
# Insertion in TileDB array
107113
tf_model_tiledb[:] = {"model_params": np.array([serialized_model])}
108114
self.update_model_metadata(array=tf_model_tiledb, meta=meta)

tiledb/ml/models/tensorflow_keras.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
import tiledb
2222

23-
from .base import Meta, TileDBModel, Timestamp
23+
from .base import Meta, TileDBModel, Timestamp, current_milli_time
2424

2525

2626
class TensorflowKerasTileDBModel(TileDBModel[tf.keras.Model]):
@@ -88,6 +88,8 @@ def load(
8888
:param input_shape: The shape that the custom model expects as input
8989
:return: Tensorflow model.
9090
"""
91+
# TODO: Change timestamp when issue in core is resolved
92+
9193
with tiledb.open(self.uri, ctx=self.ctx, timestamp=timestamp) as model_array:
9294
model_array_results = model_array[:]
9395
model_config = json.loads(model_array.meta["model_config"])
@@ -257,7 +259,10 @@ def _write_array(
257259
) -> None:
258260
"""Write Tensorflow model to a TileDB array."""
259261
assert self.model
260-
with tiledb.open(self.uri, "w", ctx=self.ctx) as tf_model_tiledb:
262+
# TODO: Change timestamp when issue in core is resolved
263+
with tiledb.open(
264+
self.uri, "w", timestamp=current_milli_time(), ctx=self.ctx
265+
) as tf_model_tiledb:
261266
if isinstance(self.model, (Functional, Sequential)):
262267
tf_model_tiledb[:] = {
263268
"model_weights": np.array([serialized_weights]),

0 commit comments

Comments
 (0)