1
1
"""Tests for TileDB PyTorch model save and load."""
2
2
3
+ import glob
3
4
import inspect
4
5
import os
5
6
import pickle
6
7
import platform
7
- import sys
8
+ import shutil
8
9
9
10
import pytest
10
11
import torch
@@ -87,23 +88,19 @@ def forward(self, x):
87
88
return x
88
89
89
90
90
- @pytest .mark .parametrize (
91
- "optimizer" ,
92
- [
93
- getattr (optimizers , name )
94
- for name , obj in inspect .getmembers (optimizers )
95
- if inspect .isclass (obj ) and name != "Optimizer"
96
- ],
97
- )
98
- @pytest .mark .parametrize (
99
- "net" ,
100
- [
101
- getattr (sys .modules [__name__ ], name )
102
- for name , obj in inspect .getmembers (sys .modules [__name__ ])
103
- if inspect .isclass (obj ) and obj .__module__ == __name__
104
- ],
105
- )
91
+ net = pytest .mark .parametrize ("net" , [ConvNet , Net , SeqNeuralNetwork ])
92
+
93
+
106
94
class TestPyTorchModel :
95
+ @net
96
+ @pytest .mark .parametrize (
97
+ "optimizer" ,
98
+ [
99
+ getattr (optimizers , name )
100
+ for name , obj in inspect .getmembers (optimizers )
101
+ if inspect .isclass (obj ) and name != "Optimizer"
102
+ ],
103
+ )
107
104
def test_save (self , tmpdir , net , optimizer ):
108
105
EPOCH = 5
109
106
LOSS = 0.4
@@ -139,7 +136,8 @@ def test_save(self, tmpdir, net, optimizer):
139
136
):
140
137
assert all ([a == b for a , b in zip (key_item_1 [1 ], key_item_2 [1 ])])
141
138
142
- def test_preview (self , tmpdir , net , optimizer ):
139
+ @net
140
+ def test_preview (self , tmpdir , net ):
143
141
# With model given as argument
144
142
model = net ()
145
143
tiledb_array = os .path .join (tmpdir , "model_array" )
@@ -148,7 +146,8 @@ def test_preview(self, tmpdir, net, optimizer):
148
146
tiledb_obj_none = PyTorchTileDBModel (uri = tiledb_array , model = None )
149
147
assert tiledb_obj_none .preview () == ""
150
148
151
- def test_file_properties (self , tmpdir , net , optimizer ):
149
+ @net
150
+ def test_file_properties (self , tmpdir , net ):
152
151
model = net ()
153
152
tiledb_array = os .path .join (tmpdir , "model_array" )
154
153
tiledb_obj = PyTorchTileDBModel (uri = tiledb_array , model = model )
@@ -165,38 +164,42 @@ def test_file_properties(self, tmpdir, net, optimizer):
165
164
)
166
165
assert tiledb_obj ._file_properties ["TILEDB_ML_MODEL_PREVIEW" ] == str (model )
167
166
168
- def test_tensorboard_callback_meta (self , tmpdir , net , optimizer , mocker ):
167
+ @net
168
+ def test_tensorboard_callback_meta (self , tmpdir , net ):
169
169
model = net ()
170
170
tiledb_array = os .path .join (tmpdir , "model_array" )
171
171
tiledb_obj = PyTorchTileDBModel (uri = tiledb_array , model = model )
172
172
173
- mocker .patch (
174
- "tiledb.ml.models.pytorch.PyTorchTileDBModel._get_tensorboard_files" ,
175
- return_value = {
176
- f"{ tmpdir } /event_file_name_1" : b"test_bytes_1" ,
177
- f"{ tmpdir } /event_file_name_2" : b"test_bytes_2" ,
178
- },
179
- )
173
+ # SummaryWriter creates file(s) under log_dir
174
+ log_dir = os .path .join (tmpdir , "logs" )
175
+ writer = SummaryWriter (log_dir = log_dir )
176
+ log_files = read_files (log_dir )
177
+ assert log_files
180
178
181
- writer = SummaryWriter ()
182
179
tiledb_obj .save (update = False , summary_writer = writer )
183
-
184
180
with tiledb .open (tiledb_array ) as A :
185
- assert len (pickle .loads (A .meta ["__TENSORBOARD__" ])) == 2
186
- assert pickle .loads (A .meta ["__TENSORBOARD__" ]) == {
187
- f"{ tmpdir } /event_file_name_1" : b"test_bytes_1" ,
188
- f"{ tmpdir } /event_file_name_2" : b"test_bytes_2" ,
189
- }
181
+ assert pickle .loads (A .meta ["__TENSORBOARD__" ]) == log_files
182
+ shutil .rmtree (log_dir )
190
183
191
184
# Loading the event data should create local files
192
185
tiledb_obj .load_tensorboard ()
193
- assert os . path . exists ( f" { tmpdir } /event_file_name_1" )
194
- assert os . path . exists ( f" { tmpdir } /event_file_name_2" )
186
+ new_log_files = read_files ( log_dir )
187
+ assert new_log_files == log_files
195
188
196
189
custom_dir = os .path .join (tmpdir , "custom_log" )
197
190
tiledb_obj .load_tensorboard (target_dir = custom_dir )
198
- assert os .path .exists (f"{ custom_dir } /event_file_name_1" )
199
- assert os .path .exists (f"{ custom_dir } /event_file_name_2" )
191
+ new_log_files = read_files (custom_dir )
192
+ assert len (new_log_files ) == len (log_files )
193
+ for new_file , old_file in zip (new_log_files .values (), log_files .values ()):
194
+ assert new_file == old_file
195
+
196
+
197
+ def read_files (dirpath ):
198
+ files = {}
199
+ for path in glob .glob (f"{ dirpath } /*" ):
200
+ with open (path , "rb" ) as f :
201
+ files [path ] = f .read ()
202
+ return files
200
203
201
204
202
205
class TestPyTorchModelCloud :
0 commit comments