11
11
"""
12
12
13
13
14
+ import contextlib
14
15
import os
15
16
from os .path import join as pjoin
16
17
import tempfile
@@ -74,7 +75,7 @@ def __getattribute__(self, name):
74
75
val = object .__getattribute__ (self , name )
75
76
if name == 'series' and val is None :
76
77
val = []
77
- with _db_nochange () as c :
78
+ with DB . readonly_cursor () as c :
78
79
c .execute ("SELECT * FROM series WHERE study = ?" , (self .uid , ))
79
80
cols = [el [0 ] for el in c .description ]
80
81
for row in c :
@@ -106,7 +107,7 @@ def __getattribute__(self, name):
106
107
val = object .__getattribute__ (self , name )
107
108
if name == 'storage_instances' and val is None :
108
109
val = []
109
- with _db_nochange () as c :
110
+ with DB . readonly_cursor () as c :
110
111
query = """SELECT *
111
112
FROM storage_instance
112
113
WHERE series = ?
@@ -227,7 +228,7 @@ def __init__(self, d):
227
228
def __getattribute__ (self , name ):
228
229
val = object .__getattribute__ (self , name )
229
230
if name == 'files' and val is None :
230
- with _db_nochange () as c :
231
+ with DB . readonly_cursor () as c :
231
232
query = """SELECT directory, name
232
233
FROM file
233
234
WHERE storage_instance = ?
@@ -241,34 +242,6 @@ def dicom(self):
241
242
return pydicom .read_file (self .files [0 ])
242
243
243
244
244
- class _db_nochange :
245
- """context guard for read-only database access"""
246
-
247
- def __enter__ (self ):
248
- self .c = DB .cursor ()
249
- return self .c
250
-
251
- def __exit__ (self , type , value , traceback ):
252
- if type is None :
253
- self .c .close ()
254
- DB .rollback ()
255
-
256
-
257
- class _db_change :
258
- """context guard for database access requiring a commit"""
259
-
260
- def __enter__ (self ):
261
- self .c = DB .cursor ()
262
- return self .c
263
-
264
- def __exit__ (self , type , value , traceback ):
265
- if type is None :
266
- self .c .close ()
267
- DB .commit ()
268
- else :
269
- DB .rollback ()
270
-
271
-
272
245
def _get_subdirs (base_dir , files_dict = None , followlinks = False ):
273
246
dirs = []
274
247
for (dirpath , dirnames , filenames ) in os .walk (base_dir , followlinks = followlinks ):
@@ -288,7 +261,7 @@ def update_cache(base_dir, followlinks=False):
288
261
for d in dirs :
289
262
os .stat (d )
290
263
mtimes [d ] = os .stat (d ).st_mtime
291
- with _db_nochange () as c :
264
+ with DB . readwrite_cursor () as c :
292
265
c .execute ("SELECT path, mtime FROM directory" )
293
266
db_mtimes = dict (c )
294
267
c .execute ("SELECT uid FROM study" )
@@ -297,7 +270,6 @@ def update_cache(base_dir, followlinks=False):
297
270
series = [row [0 ] for row in c ]
298
271
c .execute ("SELECT uid FROM storage_instance" )
299
272
storage_instances = [row [0 ] for row in c ]
300
- with _db_change () as c :
301
273
for dir in sorted (mtimes .keys ()):
302
274
if dir in db_mtimes and mtimes [dir ] <= db_mtimes [dir ]:
303
275
continue
@@ -316,7 +288,7 @@ def get_studies(base_dir=None, followlinks=False):
316
288
if base_dir is not None :
317
289
update_cache (base_dir , followlinks )
318
290
if base_dir is None :
319
- with _db_nochange () as c :
291
+ with DB . readonly_cursor () as c :
320
292
c .execute ("SELECT * FROM study" )
321
293
studies = []
322
294
cols = [el [0 ] for el in c .description ]
@@ -331,7 +303,7 @@ def get_studies(base_dir=None, followlinks=False):
331
303
WHERE uid IN (SELECT storage_instance
332
304
FROM file
333
305
WHERE directory = ?))"""
334
- with _db_nochange () as c :
306
+ with DB . readonly_cursor () as c :
335
307
study_uids = {}
336
308
for dir in _get_subdirs (base_dir , followlinks = followlinks ):
337
309
c .execute (query , (dir , ))
@@ -443,7 +415,7 @@ def _update_file(c, path, fname, studies, series, storage_instances):
443
415
444
416
445
417
def clear_cache ():
446
- with _db_change () as c :
418
+ with DB . readwrite_cursor () as c :
447
419
c .execute ("DELETE FROM file" )
448
420
c .execute ("DELETE FROM directory" )
449
421
c .execute ("DELETE FROM storage_instance" )
@@ -478,26 +450,64 @@ def clear_cache():
478
450
mtime INTEGER NOT NULL,
479
451
storage_instance TEXT DEFAULT NULL REFERENCES storage_instance,
480
452
PRIMARY KEY (directory, name))""" )
481
- DB_FNAME = pjoin (tempfile .gettempdir (), f'dft.{ getpass .getuser ()} .sqlite' )
482
- DB = None
483
453
484
454
485
- def _init_db (verbose = True ):
486
- """ Initialize database """
487
- if verbose :
488
- logger .info ('db filename: ' + DB_FNAME )
489
- global DB
490
- DB = sqlite3 .connect (DB_FNAME , check_same_thread = False )
491
- with _db_change () as c :
492
- c .execute ("SELECT COUNT(*) FROM sqlite_master WHERE type = 'table'" )
493
- if c .fetchone ()[0 ] == 0 :
494
- logger .debug ('create' )
495
- for q in CREATE_QUERIES :
496
- c .execute (q )
455
+ class _DB :
456
+ def __init__ (self , fname = None , verbose = True ):
457
+ self .fname = fname or pjoin (tempfile .gettempdir (), f'dft.{ getpass .getuser ()} .sqlite' )
458
+ self .verbose = verbose
459
+
460
+ @property
461
+ def session (self ):
462
+ """Get sqlite3 Connection
463
+
464
+ The connection is created on the first call of this property
465
+ """
466
+ try :
467
+ return self ._session
468
+ except AttributeError :
469
+ self ._init_db ()
470
+ return self ._session
471
+
472
+ def _init_db (self ):
473
+ if self .verbose :
474
+ logger .info ('db filename: ' + self .fname )
475
+
476
+ self ._session = sqlite3 .connect (self .fname , isolation_level = "EXCLUSIVE" )
477
+ with self .readwrite_cursor () as c :
478
+ c .execute ("SELECT COUNT(*) FROM sqlite_master WHERE type = 'table'" )
479
+ if c .fetchone ()[0 ] == 0 :
480
+ logger .debug ('create' )
481
+ for q in CREATE_QUERIES :
482
+ c .execute (q )
483
+
484
+ def __repr__ (self ):
485
+ return f"<DFT { self .fname !r} >"
486
+
487
+ @contextlib .contextmanager
488
+ def readonly_cursor (self ):
489
+ cursor = self .session .cursor ()
490
+ try :
491
+ yield cursor
492
+ finally :
493
+ cursor .close ()
494
+ self .session .rollback ()
495
+
496
+ @contextlib .contextmanager
497
+ def readwrite_cursor (self ):
498
+ cursor = self .session .cursor ()
499
+ try :
500
+ yield cursor
501
+ except Exception :
502
+ self .session .rollback ()
503
+ raise
504
+ finally :
505
+ cursor .close ()
506
+ self .session .commit ()
497
507
498
508
509
+ DB = None
499
510
if os .name == 'nt' :
500
511
warnings .warn ('dft needs FUSE which is not available for windows' )
501
512
else :
502
- _init_db ()
503
- # eof
513
+ DB = _DB ()
0 commit comments