|
1 | 1 | from __future__ import annotations
|
2 | 2 |
|
| 3 | +import weakref |
| 4 | + |
3 | 5 | import h5py
|
4 | 6 | import nxmx
|
5 | 7 |
|
| 8 | +import scitbx.array_family.flex as flex |
| 9 | + |
6 | 10 | import dxtbx.nexus
|
7 | 11 | from dxtbx.format.FormatNexus import FormatNexus
|
8 | 12 |
|
9 | 13 |
|
10 |
| -# A singleton to hold unique static_mask objects to avoid duplications |
11 |
| -class MaskDict(dict): |
12 |
| - def hash_func(self, mask): |
13 |
| - return hash(mask[0].as_numpy_array().tobytes()) |
| 14 | +class _MaskCache: |
| 15 | + """A singleton to hold unique static_mask objects to avoid duplications""" |
| 16 | + |
| 17 | + def __init__(self): |
| 18 | + self.local_mask_cache = weakref.WeakValueDictionary() |
14 | 19 |
|
15 |
| - def insert(self, mask): |
16 |
| - mask_hash = self.hash_func(mask) |
17 |
| - if mask_hash not in self: |
18 |
| - mask_dict[mask_hash] = mask |
19 |
| - return mask_dict[mask_hash] |
| 20 | + def _mask_hasher(self, mask: flex.bool) -> int: |
| 21 | + return hash(mask.as_numpy_array().tobytes()) |
20 | 22 |
|
| 23 | + def store_unique_and_get( |
| 24 | + self, mask_tuple: tuple[flex.bool, ...] | None |
| 25 | + ) -> tuple[flex.bool, ...] | None: |
| 26 | + if mask_tuple is None: |
| 27 | + return None |
| 28 | + output = [] |
| 29 | + for mask in mask_tuple: |
| 30 | + mask_hash = self._mask_hasher(mask) |
| 31 | + self.local_mask_cache[mask_hash] = mask |
| 32 | + output.append(mask) |
| 33 | + return tuple(output) |
21 | 34 |
|
22 |
| -mask_dict = MaskDict() |
| 35 | + |
| 36 | +mask_cache = _MaskCache() |
23 | 37 |
|
24 | 38 |
|
25 | 39 | def detector_between_sample_and_source(detector, beam):
|
@@ -98,7 +112,9 @@ def _start(self):
|
98 | 112 | self._detector_model = inverted_distance_detector(self._detector_model)
|
99 | 113 |
|
100 | 114 | self._scan_model = dxtbx.nexus.get_dxtbx_scan(nxsample, nxdetector)
|
101 |
| - self._static_mask = mask_dict.insert(dxtbx.nexus.get_static_mask(nxdetector)) |
| 115 | + self._static_mask = mask_cache.store_unique_and_get( |
| 116 | + dxtbx.nexus.get_static_mask(nxdetector) |
| 117 | + ) |
102 | 118 | self._bit_depth_readout = nxdetector.bit_depth_readout
|
103 | 119 |
|
104 | 120 | if self._scan_model:
|
|
0 commit comments