Skip to content

Commit 2b6bb6a

Browse files
authored
Extract tensorboard callback logic in a separate module (#142)
1 parent d7e6155 commit 2b6bb6a

File tree

5 files changed

+101
-99
lines changed

5 files changed

+101
-99
lines changed

tests/models/test_pytorch_models.py

+41-38
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,11 @@
11
"""Tests for TileDB PyTorch model save and load."""
22

3+
import glob
34
import inspect
45
import os
56
import pickle
67
import platform
7-
import sys
8+
import shutil
89

910
import pytest
1011
import torch
@@ -87,23 +88,19 @@ def forward(self, x):
8788
return x
8889

8990

90-
@pytest.mark.parametrize(
91-
"optimizer",
92-
[
93-
getattr(optimizers, name)
94-
for name, obj in inspect.getmembers(optimizers)
95-
if inspect.isclass(obj) and name != "Optimizer"
96-
],
97-
)
98-
@pytest.mark.parametrize(
99-
"net",
100-
[
101-
getattr(sys.modules[__name__], name)
102-
for name, obj in inspect.getmembers(sys.modules[__name__])
103-
if inspect.isclass(obj) and obj.__module__ == __name__
104-
],
105-
)
91+
net = pytest.mark.parametrize("net", [ConvNet, Net, SeqNeuralNetwork])
92+
93+
10694
class TestPyTorchModel:
95+
@net
96+
@pytest.mark.parametrize(
97+
"optimizer",
98+
[
99+
getattr(optimizers, name)
100+
for name, obj in inspect.getmembers(optimizers)
101+
if inspect.isclass(obj) and name != "Optimizer"
102+
],
103+
)
107104
def test_save(self, tmpdir, net, optimizer):
108105
EPOCH = 5
109106
LOSS = 0.4
@@ -139,7 +136,8 @@ def test_save(self, tmpdir, net, optimizer):
139136
):
140137
assert all([a == b for a, b in zip(key_item_1[1], key_item_2[1])])
141138

142-
def test_preview(self, tmpdir, net, optimizer):
139+
@net
140+
def test_preview(self, tmpdir, net):
143141
# With model given as argument
144142
model = net()
145143
tiledb_array = os.path.join(tmpdir, "model_array")
@@ -148,7 +146,8 @@ def test_preview(self, tmpdir, net, optimizer):
148146
tiledb_obj_none = PyTorchTileDBModel(uri=tiledb_array, model=None)
149147
assert tiledb_obj_none.preview() == ""
150148

151-
def test_file_properties(self, tmpdir, net, optimizer):
149+
@net
150+
def test_file_properties(self, tmpdir, net):
152151
model = net()
153152
tiledb_array = os.path.join(tmpdir, "model_array")
154153
tiledb_obj = PyTorchTileDBModel(uri=tiledb_array, model=model)
@@ -165,38 +164,42 @@ def test_file_properties(self, tmpdir, net, optimizer):
165164
)
166165
assert tiledb_obj._file_properties["TILEDB_ML_MODEL_PREVIEW"] == str(model)
167166

168-
def test_tensorboard_callback_meta(self, tmpdir, net, optimizer, mocker):
167+
@net
168+
def test_tensorboard_callback_meta(self, tmpdir, net):
169169
model = net()
170170
tiledb_array = os.path.join(tmpdir, "model_array")
171171
tiledb_obj = PyTorchTileDBModel(uri=tiledb_array, model=model)
172172

173-
mocker.patch(
174-
"tiledb.ml.models.pytorch.PyTorchTileDBModel._get_tensorboard_files",
175-
return_value={
176-
f"{tmpdir}/event_file_name_1": b"test_bytes_1",
177-
f"{tmpdir}/event_file_name_2": b"test_bytes_2",
178-
},
179-
)
173+
# SummaryWriter creates file(s) under log_dir
174+
log_dir = os.path.join(tmpdir, "logs")
175+
writer = SummaryWriter(log_dir=log_dir)
176+
log_files = read_files(log_dir)
177+
assert log_files
180178

181-
writer = SummaryWriter()
182179
tiledb_obj.save(update=False, summary_writer=writer)
183-
184180
with tiledb.open(tiledb_array) as A:
185-
assert len(pickle.loads(A.meta["__TENSORBOARD__"])) == 2
186-
assert pickle.loads(A.meta["__TENSORBOARD__"]) == {
187-
f"{tmpdir}/event_file_name_1": b"test_bytes_1",
188-
f"{tmpdir}/event_file_name_2": b"test_bytes_2",
189-
}
181+
assert pickle.loads(A.meta["__TENSORBOARD__"]) == log_files
182+
shutil.rmtree(log_dir)
190183

191184
# Loading the event data should create local files
192185
tiledb_obj.load_tensorboard()
193-
assert os.path.exists(f"{tmpdir}/event_file_name_1")
194-
assert os.path.exists(f"{tmpdir}/event_file_name_2")
186+
new_log_files = read_files(log_dir)
187+
assert new_log_files == log_files
195188

196189
custom_dir = os.path.join(tmpdir, "custom_log")
197190
tiledb_obj.load_tensorboard(target_dir=custom_dir)
198-
assert os.path.exists(f"{custom_dir}/event_file_name_1")
199-
assert os.path.exists(f"{custom_dir}/event_file_name_2")
191+
new_log_files = read_files(custom_dir)
192+
assert len(new_log_files) == len(log_files)
193+
for new_file, old_file in zip(new_log_files.values(), log_files.values()):
194+
assert new_file == old_file
195+
196+
197+
def read_files(dirpath):
198+
files = {}
199+
for path in glob.glob(f"{dirpath}/*"):
200+
with open(path, "rb") as f:
201+
files[path] = f.read()
202+
return files
200203

201204

202205
class TestPyTorchModelCloud:

tests/models/test_tensorflow_keras_models.py

+18-15
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
import os
55
import pickle
66
import platform
7+
import shutil
78

89
import numpy as np
910
import pytest
@@ -531,36 +532,38 @@ def test_exception_raise_file_property_in_meta_error(self, tmpdir):
531532
ex.value
532533
)
533534

534-
def test_tensorboard_callback_meta(self, tmpdir, mocker):
535+
def test_tensorboard_callback_meta(self, tmpdir):
535536
model = keras.models.Sequential()
536537
model.add(keras.layers.Flatten(input_shape=(10, 10)))
537538
tiledb_array = os.path.join(tmpdir, "model_array")
538539
tiledb_obj = TensorflowKerasTileDBModel(uri=tiledb_array, model=model)
539540

540541
cb = [tf.keras.callbacks.TensorBoard(log_dir=tmpdir)]
541542

542-
mocker.patch(
543-
"tiledb.ml.models.tensorflow_keras.TensorflowKerasTileDBModel._get_tensorboard_files",
544-
return_value={
545-
f"{tmpdir}/event_file_name_1": b"test_bytes_1",
546-
f"{tmpdir}/event_file_name_2": b"test_bytes_2",
547-
},
548-
)
543+
os.mkdir(os.path.join(tmpdir, "train"))
544+
with open(os.path.join(tmpdir, "train", "foo_tfevents_1"), "wb") as f:
545+
f.write(b"test_bytes_1")
546+
with open(os.path.join(tmpdir, "train", "bar_tfevents_2"), "wb") as f:
547+
f.write(b"test_bytes_2")
549548

550549
tiledb_obj.save(include_callbacks=cb)
551550
with tiledb.open(tiledb_array) as A:
552-
assert len(pickle.loads(A.meta["__TENSORBOARD__"])) == 2
553551
assert pickle.loads(A.meta["__TENSORBOARD__"]) == {
554-
f"{tmpdir}/event_file_name_1": b"test_bytes_1",
555-
f"{tmpdir}/event_file_name_2": b"test_bytes_2",
552+
os.path.join(tmpdir, "train", "foo_tfevents_1"): b"test_bytes_1",
553+
os.path.join(tmpdir, "train", "bar_tfevents_2"): b"test_bytes_2",
556554
}
555+
shutil.rmtree(os.path.join(tmpdir, "train"))
557556

558557
# Loading the event data should create local files
559558
tiledb_obj.load_tensorboard()
560-
assert os.path.exists(f"{tmpdir}/event_file_name_1")
561-
assert os.path.exists(f"{tmpdir}/event_file_name_2")
559+
with open(os.path.join(tmpdir, "train", "foo_tfevents_1"), "rb") as f:
560+
assert f.read() == b"test_bytes_1"
561+
with open(os.path.join(tmpdir, "train", "bar_tfevents_2"), "rb") as f:
562+
assert f.read() == b"test_bytes_2"
562563

563564
custom_dir = os.path.join(tmpdir, "custom_log")
564565
tiledb_obj.load_tensorboard(target_dir=custom_dir)
565-
assert os.path.exists(f"{custom_dir}/event_file_name_1")
566-
assert os.path.exists(f"{custom_dir}/event_file_name_2")
566+
with open(os.path.join(custom_dir, "foo_tfevents_1"), "rb") as f:
567+
assert f.read() == b"test_bytes_1"
568+
with open(os.path.join(custom_dir, "bar_tfevents_2"), "rb") as f:
569+
assert f.read() == b"test_bytes_2"

tiledb/ml/models/_tensorboard.py

+33
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,33 @@
1+
import glob
2+
import os
3+
import pickle
4+
from typing import Mapping, Optional
5+
6+
import tiledb
7+
8+
from .base import Timestamp
9+
10+
_KEY = "__TENSORBOARD__"
11+
12+
13+
def save_tensorboard(log_dir: str) -> Mapping[str, bytes]:
14+
event_files = {}
15+
for path in glob.glob(f"{log_dir}/*tfevents*"):
16+
with open(path, "rb") as f:
17+
event_files[path] = f.read()
18+
return {_KEY: pickle.dumps(event_files, protocol=4)}
19+
20+
21+
def load_tensorboard(
22+
uri: str,
23+
ctx: Optional[tiledb.Ctx] = None,
24+
target_dir: Optional[str] = None,
25+
timestamp: Optional[Timestamp] = None,
26+
) -> None:
27+
with tiledb.open(uri, ctx=ctx, timestamp=timestamp) as model_array:
28+
for path, file_bytes in pickle.loads(model_array.meta[_KEY]).items():
29+
log_dir = target_dir if target_dir else os.path.dirname(path)
30+
if not os.path.exists(log_dir):
31+
os.mkdir(log_dir)
32+
with open(os.path.join(log_dir, os.path.basename(path)), "wb") as f:
33+
f.write(file_bytes)

tiledb/ml/models/pytorch.py

+4-23
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,5 @@
11
"""Functionality for saving and loading PyTorch models as TileDB arrays"""
22

3-
import glob
4-
import os
53
import pickle
64
from typing import Any, Mapping, Optional
75

@@ -13,6 +11,7 @@
1311
import tiledb
1412

1513
from ._cloud_utils import update_file_properties
14+
from ._tensorboard import load_tensorboard, save_tensorboard
1615
from .base import Meta, TileDBModel, Timestamp, current_milli_time
1716

1817

@@ -81,11 +80,8 @@ def save(
8180

8281
# Summary writer
8382
if summary_writer:
84-
event_files = self._get_tensorboard_files(summary_writer.log_dir)
85-
meta = {
86-
"__TENSORBOARD__": pickle.dumps(event_files, protocol=4),
87-
**(meta or {}),
88-
}
83+
cb_meta = save_tensorboard(summary_writer.log_dir)
84+
meta = {**meta, **cb_meta} if meta else cb_meta
8985

9086
# Create TileDB model array
9187
if not update:
@@ -150,14 +146,7 @@ def load_tensorboard(
150146
target_dir: Optional[str] = None,
151147
timestamp: Optional[Timestamp] = None,
152148
) -> None:
153-
with tiledb.open(self.uri, ctx=self.ctx, timestamp=timestamp) as model_array:
154-
tb_data = pickle.loads(model_array.meta["__TENSORBOARD__"])
155-
for path in tb_data.keys():
156-
log_dir = target_dir if target_dir else os.path.dirname(path)
157-
if not os.path.exists(log_dir):
158-
os.mkdir(log_dir)
159-
with open(os.path.join(log_dir, os.path.basename(path)), "wb") as f:
160-
f.write(tb_data[path])
149+
return load_tensorboard(self.uri, self.ctx, target_dir, timestamp)
161150

162151
def preview(self) -> str:
163152
"""
@@ -250,11 +239,3 @@ def _write_array(
250239
key: np.array([value]) for key, value in serialized_model_dict.items()
251240
}
252241
self.update_model_metadata(array=tf_model_tiledb, meta=meta)
253-
254-
@staticmethod
255-
def _get_tensorboard_files(log_dir: str) -> Mapping[str, bytes]:
256-
event_files = {}
257-
for path in glob.glob(f"{log_dir}/*tfevents*"):
258-
with open(path, "rb") as f:
259-
event_files[path] = f.read()
260-
return event_files

tiledb/ml/models/tensorflow_keras.py

+5-23
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,10 @@
11
"""Functionality for saving and loading Tensorflow Keras models as TileDB arrays"""
22

33
import contextlib
4-
import glob
54
import io
65
import json
76
import logging
8-
import os.path
7+
import os
98
import pickle
109
from operator import attrgetter
1110
from typing import Any, List, Mapping, Optional, Tuple
@@ -25,6 +24,7 @@
2524
import tiledb
2625

2726
from ._cloud_utils import update_file_properties
27+
from ._tensorboard import load_tensorboard, save_tensorboard
2828
from .base import Meta, TileDBModel, Timestamp, current_milli_time
2929

3030
# SharedObjectLoadingScope was introduced in TensorFlow 2.5
@@ -73,11 +73,8 @@ def save(
7373
if include_callbacks:
7474
for cb in include_callbacks:
7575
if isinstance(cb, TensorBoard):
76-
event_files = self._get_tensorboard_files(cb.log_dir)
77-
meta = {
78-
"__TENSORBOARD__": pickle.dumps(event_files, protocol=4),
79-
**(meta or {}),
80-
}
76+
cb_meta = save_tensorboard(os.path.join(cb.log_dir, "train"))
77+
meta = {**meta, **cb_meta} if meta else cb_meta
8178

8279
# Create TileDB model array
8380
if not update:
@@ -182,14 +179,7 @@ def load_tensorboard(
182179
target_dir: Optional[str] = None,
183180
timestamp: Optional[Timestamp] = None,
184181
) -> None:
185-
with tiledb.open(self.uri, ctx=self.ctx, timestamp=timestamp) as model_array:
186-
tb_data = pickle.loads(model_array.meta["__TENSORBOARD__"])
187-
for path in tb_data.keys():
188-
log_dir = target_dir if target_dir else os.path.dirname(path)
189-
if not os.path.exists(log_dir):
190-
os.mkdir(log_dir)
191-
with open(os.path.join(log_dir, os.path.basename(path)), "wb") as f:
192-
f.write(tb_data[path])
182+
return load_tensorboard(self.uri, self.ctx, target_dir, timestamp)
193183

194184
def preview(self) -> str:
195185
"""Create a string representation of the model."""
@@ -419,11 +409,3 @@ def _load_weights_from_tiledb(
419409
)
420410
var_value_tuples.extend(zip(weight_vars, read_weight_values))
421411
backend.batch_set_value(var_value_tuples)
422-
423-
@staticmethod
424-
def _get_tensorboard_files(log_dir: str) -> Mapping[str, bytes]:
425-
event_files = {}
426-
for path in glob.glob(f"{log_dir}/train/*tfevents*"):
427-
with open(path, "rb") as f:
428-
event_files[path] = f.read()
429-
return event_files

0 commit comments

Comments
 (0)