Skip to content

Commit 943e5e8

Browse files
authored
Merge pull request #1158 from effigies/fix/xdist-safe-dft
RF: Write DFT database manager as object
2 parents 8cf190d + 32bc89a commit 943e5e8

File tree

2 files changed

+113
-75
lines changed

2 files changed

+113
-75
lines changed

nibabel/dft.py

+62-52
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
"""
1212

1313

14+
import contextlib
1415
import os
1516
from os.path import join as pjoin
1617
import tempfile
@@ -74,7 +75,7 @@ def __getattribute__(self, name):
7475
val = object.__getattribute__(self, name)
7576
if name == 'series' and val is None:
7677
val = []
77-
with _db_nochange() as c:
78+
with DB.readonly_cursor() as c:
7879
c.execute("SELECT * FROM series WHERE study = ?", (self.uid, ))
7980
cols = [el[0] for el in c.description]
8081
for row in c:
@@ -106,7 +107,7 @@ def __getattribute__(self, name):
106107
val = object.__getattribute__(self, name)
107108
if name == 'storage_instances' and val is None:
108109
val = []
109-
with _db_nochange() as c:
110+
with DB.readonly_cursor() as c:
110111
query = """SELECT *
111112
FROM storage_instance
112113
WHERE series = ?
@@ -227,7 +228,7 @@ def __init__(self, d):
227228
def __getattribute__(self, name):
228229
val = object.__getattribute__(self, name)
229230
if name == 'files' and val is None:
230-
with _db_nochange() as c:
231+
with DB.readonly_cursor() as c:
231232
query = """SELECT directory, name
232233
FROM file
233234
WHERE storage_instance = ?
@@ -241,34 +242,6 @@ def dicom(self):
241242
return pydicom.read_file(self.files[0])
242243

243244

244-
class _db_nochange:
245-
"""context guard for read-only database access"""
246-
247-
def __enter__(self):
248-
self.c = DB.cursor()
249-
return self.c
250-
251-
def __exit__(self, type, value, traceback):
252-
if type is None:
253-
self.c.close()
254-
DB.rollback()
255-
256-
257-
class _db_change:
258-
"""context guard for database access requiring a commit"""
259-
260-
def __enter__(self):
261-
self.c = DB.cursor()
262-
return self.c
263-
264-
def __exit__(self, type, value, traceback):
265-
if type is None:
266-
self.c.close()
267-
DB.commit()
268-
else:
269-
DB.rollback()
270-
271-
272245
def _get_subdirs(base_dir, files_dict=None, followlinks=False):
273246
dirs = []
274247
for (dirpath, dirnames, filenames) in os.walk(base_dir, followlinks=followlinks):
@@ -288,7 +261,7 @@ def update_cache(base_dir, followlinks=False):
288261
for d in dirs:
289262
os.stat(d)
290263
mtimes[d] = os.stat(d).st_mtime
291-
with _db_nochange() as c:
264+
with DB.readwrite_cursor() as c:
292265
c.execute("SELECT path, mtime FROM directory")
293266
db_mtimes = dict(c)
294267
c.execute("SELECT uid FROM study")
@@ -297,7 +270,6 @@ def update_cache(base_dir, followlinks=False):
297270
series = [row[0] for row in c]
298271
c.execute("SELECT uid FROM storage_instance")
299272
storage_instances = [row[0] for row in c]
300-
with _db_change() as c:
301273
for dir in sorted(mtimes.keys()):
302274
if dir in db_mtimes and mtimes[dir] <= db_mtimes[dir]:
303275
continue
@@ -316,7 +288,7 @@ def get_studies(base_dir=None, followlinks=False):
316288
if base_dir is not None:
317289
update_cache(base_dir, followlinks)
318290
if base_dir is None:
319-
with _db_nochange() as c:
291+
with DB.readonly_cursor() as c:
320292
c.execute("SELECT * FROM study")
321293
studies = []
322294
cols = [el[0] for el in c.description]
@@ -331,7 +303,7 @@ def get_studies(base_dir=None, followlinks=False):
331303
WHERE uid IN (SELECT storage_instance
332304
FROM file
333305
WHERE directory = ?))"""
334-
with _db_nochange() as c:
306+
with DB.readonly_cursor() as c:
335307
study_uids = {}
336308
for dir in _get_subdirs(base_dir, followlinks=followlinks):
337309
c.execute(query, (dir, ))
@@ -443,7 +415,7 @@ def _update_file(c, path, fname, studies, series, storage_instances):
443415

444416

445417
def clear_cache():
446-
with _db_change() as c:
418+
with DB.readwrite_cursor() as c:
447419
c.execute("DELETE FROM file")
448420
c.execute("DELETE FROM directory")
449421
c.execute("DELETE FROM storage_instance")
@@ -478,26 +450,64 @@ def clear_cache():
478450
mtime INTEGER NOT NULL,
479451
storage_instance TEXT DEFAULT NULL REFERENCES storage_instance,
480452
PRIMARY KEY (directory, name))""")
481-
DB_FNAME = pjoin(tempfile.gettempdir(), f'dft.{getpass.getuser()}.sqlite')
482-
DB = None
483453

484454

485-
def _init_db(verbose=True):
486-
""" Initialize database """
487-
if verbose:
488-
logger.info('db filename: ' + DB_FNAME)
489-
global DB
490-
DB = sqlite3.connect(DB_FNAME, check_same_thread=False)
491-
with _db_change() as c:
492-
c.execute("SELECT COUNT(*) FROM sqlite_master WHERE type = 'table'")
493-
if c.fetchone()[0] == 0:
494-
logger.debug('create')
495-
for q in CREATE_QUERIES:
496-
c.execute(q)
455+
class _DB:
456+
def __init__(self, fname=None, verbose=True):
457+
self.fname = fname or pjoin(tempfile.gettempdir(), f'dft.{getpass.getuser()}.sqlite')
458+
self.verbose = verbose
459+
460+
@property
461+
def session(self):
462+
"""Get sqlite3 Connection
463+
464+
The connection is created on the first call of this property
465+
"""
466+
try:
467+
return self._session
468+
except AttributeError:
469+
self._init_db()
470+
return self._session
471+
472+
def _init_db(self):
473+
if self.verbose:
474+
logger.info('db filename: ' + self.fname)
475+
476+
self._session = sqlite3.connect(self.fname, isolation_level="EXCLUSIVE")
477+
with self.readwrite_cursor() as c:
478+
c.execute("SELECT COUNT(*) FROM sqlite_master WHERE type = 'table'")
479+
if c.fetchone()[0] == 0:
480+
logger.debug('create')
481+
for q in CREATE_QUERIES:
482+
c.execute(q)
483+
484+
def __repr__(self):
485+
return f"<DFT {self.fname!r}>"
486+
487+
@contextlib.contextmanager
488+
def readonly_cursor(self):
489+
cursor = self.session.cursor()
490+
try:
491+
yield cursor
492+
finally:
493+
cursor.close()
494+
self.session.rollback()
495+
496+
@contextlib.contextmanager
497+
def readwrite_cursor(self):
498+
cursor = self.session.cursor()
499+
try:
500+
yield cursor
501+
except Exception:
502+
self.session.rollback()
503+
raise
504+
finally:
505+
cursor.close()
506+
self.session.commit()
497507

498508

509+
DB = None
499510
if os.name == 'nt':
500511
warnings.warn('dft needs FUSE which is not available for windows')
501512
else:
502-
_init_db()
503-
# eof
513+
DB = _DB()

nibabel/tests/test_dft.py

+51-23
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from os.path import join as pjoin, dirname
66
from io import BytesIO
77
from ..testing import suppress_warnings
8+
import sqlite3
89

910
with suppress_warnings():
1011
from .. import dft
@@ -29,26 +30,57 @@ def setUpModule():
2930
raise unittest.SkipTest('Need pydicom for dft tests, skipping')
3031

3132

32-
def test_init():
33+
class Test_DBclass:
    """Some tests on the database manager class that don't get exercised through the API"""

    def setup_method(self):
        # Each test method gets its own in-memory database manager
        self._db = dft._DB(fname=":memory:", verbose=False)

    def test_repr(self):
        expected = "<DFT ':memory:'>"
        assert repr(self._db) == expected

    def test_cursor_conflict(self):
        open_cursor = self._db.readwrite_cursor
        sql = "INSERT INTO directory (path, mtime) VALUES (?, ?)"
        params = ("/tmp", 0)
        # Two read/write cursors insert the same row; the duplicate insert is
        # expected to surface as an IntegrityError
        with pytest.raises(sqlite3.IntegrityError):
            with open_cursor() as first, open_cursor() as second:
                first.execute(sql, params)
                second.execute(sql, params)
49+
50+
51+
@pytest.fixture
def db(monkeypatch):
    """Swap ``dft.DB`` for an in-memory database for the duration of a test.

    This avoids cross-process races and does not modify the host filesystem.
    """
    in_memory_db = dft._DB(fname=":memory:")
    # monkeypatch restores the original ``dft.DB`` after the test
    monkeypatch.setattr(dft, "DB", in_memory_db)
    yield in_memory_db
58+
59+
60+
def test_init(db):
    """Clear the cache, then populate it twice from the test data directory."""
    dft.clear_cache()
    # The second pass verifies that re-running the update doesn't crash
    for _ in range(2):
        dft.update_cache(data_dir)
3565

3666

37-
def test_study():
38-
studies = dft.get_studies(data_dir)
39-
assert len(studies) == 1
40-
assert (studies[0].uid ==
41-
'1.3.12.2.1107.5.2.32.35119.30000010011408520750000000022')
42-
assert studies[0].date == '20100114'
43-
assert studies[0].time == '121314.000000'
44-
assert studies[0].comments == 'dft study comments'
45-
assert studies[0].patient_name == 'dft patient name'
46-
assert studies[0].patient_id == '1234'
47-
assert studies[0].patient_birth_date == '19800102'
48-
assert studies[0].patient_sex == 'F'
49-
50-
51-
def test_series():
67+
def test_study(db):
    """Check the single test study, both after a cache update and from the cache."""
    expected = {
        'uid': '1.3.12.2.1107.5.2.32.35119.30000010011408520750000000022',
        'date': '20100114',
        'time': '121314.000000',
        'comments': 'dft study comments',
        'patient_name': 'dft patient name',
        'patient_id': '1234',
        'patient_birth_date': '19800102',
        'patient_sex': 'F',
    }
    # First pass updates the cache, second pass reads it out
    for base_dir in (data_dir, None):
        studies = dft.get_studies(base_dir)
        assert len(studies) == 1
        study = studies[0]
        for attr, value in expected.items():
            assert getattr(study, attr) == value
81+
82+
83+
def test_series(db):
5284
studies = dft.get_studies(data_dir)
5385
assert len(studies[0].series) == 1
5486
ser = studies[0].series[0]
@@ -62,7 +94,7 @@ def test_series():
6294
assert ser.bits_stored == 12
6395

6496

65-
def test_storage_instances():
97+
def test_storage_instances(db):
6698
studies = dft.get_studies(data_dir)
6799
sis = studies[0].series[0].storage_instances
68100
assert len(sis) == 2
@@ -74,19 +106,15 @@ def test_storage_instances():
74106
'1.3.12.2.1107.5.2.32.35119.2010011420300180088599504.1')
75107

76108

77-
def test_storage_instance():
78-
pass
79-
80-
81109
@unittest.skipUnless(have_pil, 'could not import PIL.Image')
82-
def test_png():
110+
def test_png(db):
83111
studies = dft.get_studies(data_dir)
84112
data = studies[0].series[0].as_png()
85113
im = PImage.open(BytesIO(data))
86114
assert im.size == (256, 256)
87115

88116

89-
def test_nifti():
117+
def test_nifti(db):
90118
studies = dft.get_studies(data_dir)
91119
data = studies[0].series[0].as_nifti()
92120
assert len(data) == 352 + 2 * 256 * 256 * 2

0 commit comments

Comments
 (0)