@@ -65,7 +65,14 @@ def __init__(
65
65
self .ctx = ctx
66
66
self .artifact = artifact
67
67
self .uri = get_cloud_uri (uri , namespace ) if namespace else uri
68
- self ._file_properties = self ._get_file_properties ()
68
+ self ._file_properties = {
69
+ ModelFileProperties .TILEDB_ML_MODEL_ML_FRAMEWORK .value : self .Name ,
70
+ ModelFileProperties .TILEDB_ML_MODEL_ML_FRAMEWORK_VERSION .value : self .Version ,
71
+ ModelFileProperties .TILEDB_ML_MODEL_STAGE .value : "STAGING" ,
72
+ ModelFileProperties .TILEDB_ML_MODEL_PYTHON_VERSION .value : platform .python_version (),
73
+ ModelFileProperties .TILEDB_ML_MODEL_PREVIEW .value : self .preview (),
74
+ ModelFileProperties .TILEDB_ML_MODEL_VERSION .value : __version__ ,
75
+ }
69
76
70
77
@abstractmethod
71
78
def save (self , * , update : bool = False , meta : Optional [Meta ] = None ) -> None :
def get_weights(self, timestamp: Optional[Timestamp] = None) -> Weights:
    """Return the model's weights. Works for Tensorflow Keras and PyTorch.

    :param timestamp: Optional TileDB timestamp (or range) to read at.
    """
    model_array = tiledb.open(self.uri, ctx=self.ctx, timestamp=timestamp)
    try:
        return cast(Weights, self._get_model_param(model_array, "model"))
    finally:
        # Always release the array handle, mirroring the context-manager contract.
        model_array.close()
92
100
93
101
def get_optimizer_weights(self, timestamp: Optional[Timestamp] = None) -> Weights:
    """Return the optimizer's weights. Works for Tensorflow Keras and PyTorch.

    :param timestamp: Optional TileDB timestamp (or range) to read at.
    """
    model_array = tiledb.open(self.uri, ctx=self.ctx, timestamp=timestamp)
    try:
        return cast(Weights, self._get_model_param(model_array, "optimizer"))
    finally:
        # Always release the array handle, mirroring the context-manager contract.
        model_array.close()
98
107
99
108
@abstractmethod
def preview(self) -> str:
    """Create and return a string representation of the machine learning model."""
104
113
105
- def _get_file_properties (self ) -> Mapping [str , str ]:
106
- return {
107
- ModelFileProperties .TILEDB_ML_MODEL_ML_FRAMEWORK .value : self .Name ,
108
- ModelFileProperties .TILEDB_ML_MODEL_ML_FRAMEWORK_VERSION .value : self .Version ,
109
- ModelFileProperties .TILEDB_ML_MODEL_STAGE .value : "STAGING" ,
110
- ModelFileProperties .TILEDB_ML_MODEL_PYTHON_VERSION .value : platform .python_version (),
111
- ModelFileProperties .TILEDB_ML_MODEL_PREVIEW .value : self .preview (),
112
- ModelFileProperties .TILEDB_ML_MODEL_VERSION .value : __version__ ,
113
- }
114
-
115
- def _create_array (
116
- self ,
117
- fields : Sequence [str ],
118
- ) -> None :
114
+ def _create_array (self , fields : Sequence [str ]) -> None :
119
115
"""Internal method that creates a TileDB array based on the model's spec."""
120
116
121
117
# The array will be be 1 dimensional with domain of 0 to max uint64. We use a tile extent of 1024 bytes
@@ -152,101 +148,78 @@ def _create_array(
152
148
if self .namespace :
153
149
update_file_properties (self .uri , self ._file_properties )
154
150
155
def _write_array(
    self,
    model_params: Mapping[str, bytes],
    tensorboard_log_dir: Optional[str] = None,
    meta: Optional[Meta] = None,
) -> None:
    """Write model-related blobs (weights, optimizer, Tensorboard files) to the dense TileDB array.

    :param model_params: Mapping of attribute name to serialized bytes.
    :param tensorboard_log_dir: If given, Tensorboard event files under it are
        serialized and stored under the "tensorboard" attribute; otherwise an
        empty blob is stored for that attribute.
    :param meta: Extra user metadata to store alongside the file properties.
    :raises ValueError: If ``meta`` shares a key with the file properties.
    """
    tensorboard = (
        self._serialize_tensorboard(tensorboard_log_dir)
        if tensorboard_log_dir
        else b""
    )
    model_params = dict(tensorboard=tensorboard, **model_params)

    meta = meta or {}
    if not meta.keys().isdisjoint(self._file_properties.keys()):
        raise ValueError(
            "Please avoid using file property key names as metadata keys!"
        )

    with tiledb.open(self.uri, "w", ctx=self.ctx) as model_array:
        buffers = {
            name: np.frombuffer(blob, dtype=np.uint8)
            for name, blob in model_params.items()
        }
        max_len = 0
        for name, buf in buffers.items():
            size = len(buf)
            # Record each attribute's true length; write size only when it is greater than 0.
            if size:
                model_array.meta[name + "_size"] = size
            max_len = max(max_len, size)

        # Zero-pad every buffer to the longest one so a single dense write covers all attributes.
        model_array[0:max_len] = {
            name: np.pad(buf, (0, max_len - len(buf)))
            for name, buf in buffers.items()
        }
        # User metadata first, then file properties (same order as before).
        for mapping in (meta, self._file_properties):
            for key, value in mapping.items():
                model_array.meta[key] = value
190
+
191
def _get_model_param(self, model_array: tiledb.Array, key: str) -> Any:
    """Read and unpickle the value stored under attribute *key* of an open model array.

    The number of meaningful (non-padding) bytes is taken from the
    ``<key>_size`` metadata entry written by ``_write_array``.

    :param model_array: An open ``tiledb.Array`` holding this model.
    :param key: Attribute name, e.g. "model", "optimizer" or "tensorboard".
    :raises Exception: If the ``<key>_size`` metadata entry is missing.
    """
    size_key = key + "_size"
    try:
        size = model_array.meta[size_key]
    except KeyError as err:
        # Chain the underlying KeyError so the original failure is visible in tracebacks.
        raise Exception(
            f"{size_key} metadata entry not present in {self.uri}"
            f" (existing keys: {set(model_array.meta.keys())})"
        ) from err
    # NOTE(review): pickle.loads on array contents is only safe if the array was
    # written by trusted code — do not load arrays from untrusted sources.
    return pickle.loads(model_array.query(attrs=(key,))[0:size][key].tobytes())
199
201
200
202
@staticmethod
201
- def _serialize_tensorboard_files (log_dir : str ) -> bytes :
203
+ def _serialize_tensorboard (log_dir : str ) -> bytes :
202
204
"""Serialize all Tensorboard files."""
203
-
204
205
if not os .path .exists (log_dir ):
205
206
raise ValueError (f"{ log_dir } does not exist" )
206
-
207
- event_files = {}
207
+ tensorboard_files = {}
208
208
for path in glob .glob (f"{ log_dir } /*tfevents*" ):
209
209
with open (path , "rb" ) as f :
210
- event_files [path ] = f .read ()
210
+ tensorboard_files [path ] = f .read ()
211
+ return pickle .dumps (tensorboard_files , protocol = 4 )
211
212
212
- return pickle .dumps (event_files , protocol = 4 )
213
-
214
- def _get_model_param (self , key : str , timestamp : Optional [Timestamp ]) -> Any :
215
- with tiledb .open (self .uri , ctx = self .ctx , timestamp = timestamp ) as model_array :
216
- size_key = key + "_size"
217
- try :
218
- size = model_array .meta [size_key ]
219
- except KeyError :
220
- raise Exception (
221
- f"{ size_key } metadata entry not present in { self .uri } "
222
- f" (existing keys: { set (model_array .meta .keys ())} )"
223
- )
224
- return pickle .loads (model_array [0 :size ][key ].tobytes ())
225
-
226
def _load_tensorboard(self, model_array: tiledb.Array) -> None:
    """Write Tensorboard files back to their original directories.

    Works for Tensorflow-Keras and PyTorch. Paths are the ones recorded at
    serialization time by ``_serialize_tensorboard``.
    """
    for path, file_bytes in self._get_model_param(model_array, "tensorboard").items():
        # Recreate the original directory layout before restoring each file.
        os.makedirs(os.path.dirname(path), exist_ok=True)
        with open(path, "wb") as out:
            out.write(file_bytes)
223
def _use_legacy_schema(self, model_array: tiledb.Array) -> bool:
    """Return True if *model_array* lacks the current "offset" first dimension."""
    # TODO: Decide based on tiledb-ml version and not on schema characteristics, like "offset".
    first_dim_name = model_array.schema.domain.dim(0).name
    return str(first_dim_name) != "offset"
0 commit comments