Skip to content

Commit 45f50b2

Browse files
authored
FormatNXMX: Avoid storing duplicate static masks (#789)
When running dials.import on a large number of experiment files, RAM usage on Linux was unnecessarily high because a separate static mask was created for each file. Since the number of distinct masks in an experiment tends to be small, it makes more sense to store each unique mask only once, which is what this PR does. Closes dials/dials#2227.
1 parent 62e3332 commit 45f50b2

File tree

3 files changed

+34
-2
lines changed

3 files changed

+34
-2
lines changed

AUTHORS

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,4 +29,4 @@ Robert Oeffner
2929
Takanori Nakane
3030
Tara Michels-Clark
3131
Viktor Bengtsson
32-
32+
Yash Karan

newsfragments/789.bugfix

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
``dials.import``: Reduce excessive memory usage when importing many (hundreds or more) FormatNXmx files.

src/dxtbx/format/FormatNXmx.py

+32-1
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,41 @@
11
from __future__ import annotations
22

3+
import weakref
4+
35
import h5py
46
import nxmx
57

8+
import scitbx.array_family.flex as flex
9+
610
import dxtbx.nexus
711
from dxtbx.format.FormatNexus import FormatNexus
812

913

14+
class _MaskCache:
    """Store of unique static masks, used to avoid holding duplicates.

    Masks are keyed by a hash of their shape and contents. When a mask with
    a key that is already cached is submitted, the cached object is handed
    back instead, so the caller can drop its own copy. A single module-level
    instance, ``mask_cache``, is shared by everything in this module.
    """

    def __init__(self):
        # Weak values: an entry is discarded automatically once nothing
        # outside the cache holds a strong reference to the mask any more.
        self.local_mask_cache = weakref.WeakValueDictionary()

    def _mask_hasher(self, mask: flex.bool) -> int:
        """Return a hash identifying the shape and contents of *mask*."""
        array = mask.as_numpy_array()
        # tobytes() flattens the array, so the shape has to be part of the
        # key; otherwise masks with the same bytes but different dimensions
        # would be treated as the same mask.
        return hash((array.shape, array.tobytes()))

    def store_unique_and_get(
        self, mask_tuple: tuple[flex.bool, ...] | None
    ) -> tuple[flex.bool, ...] | None:
        """Return *mask_tuple* with every mask replaced by its cached twin.

        Masks that are not cached yet are added to the cache and returned
        unchanged. ``None`` (no static mask) is passed straight through.
        """
        if mask_tuple is None:
            return None
        return tuple(
            self.local_mask_cache.setdefault(self._mask_hasher(mask), mask)
            for mask in mask_tuple
        )


mask_cache = _MaskCache()
37+
38+
1039
def detector_between_sample_and_source(detector, beam):
1140
"""Check if the detector is perpendicular to beam and
1241
upstream of the sample."""
@@ -83,7 +112,9 @@ def _start(self):
83112
self._detector_model = inverted_distance_detector(self._detector_model)
84113

85114
self._scan_model = dxtbx.nexus.get_dxtbx_scan(nxsample, nxdetector)
86-
self._static_mask = dxtbx.nexus.get_static_mask(nxdetector)
115+
self._static_mask = mask_cache.store_unique_and_get(
116+
dxtbx.nexus.get_static_mask(nxdetector)
117+
)
87118
self._bit_depth_readout = nxdetector.bit_depth_readout
88119

89120
if self._scan_model:

0 commit comments

Comments
 (0)