Skip to content

Commit b35bccc

Browse files
Saturation speedup (#331)
2 parents 651015d + 5cecab4 commit b35bccc

File tree

2 files changed

+102
-33
lines changed

2 files changed

+102
-33
lines changed

changes/331.general.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Performance improvements for saturation step

src/stcal/saturation/saturation.py

Lines changed: 101 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -83,64 +83,85 @@ def flag_saturated_pixels(
8383

8484
for ints in range(nints):
8585
# Work forward through the groups for initial pass at saturation
86+
87+
# We want to flag saturation in all subsequent groups after
88+
# the one in which it was found. Use this boolean array to
89+
# keep a running tally of pixels that have saturated.
90+
previously_saturated = np.zeros(shape=(nrows, ncols), dtype='bool')
91+
8692
for group in range(ngroups):
8793
plane = data[ints, group, :, :]
8894

89-
flagarray, flaglowarray = plane_saturation(plane, sat_thresh, dqflags)
90-
9195
# for saturation, the flag is set in the current plane
9296
# and all following planes.
93-
np.bitwise_or(gdq[ints, group:, :, :], flagarray, gdq[ints, group:, :, :])
97+
98+
# Update the running tally of all pixels that have ever
99+
# experienced saturation to account for this.
100+
101+
previously_saturated |= (plane >= sat_thresh)
102+
flagarray = (previously_saturated * saturated).astype(np.uint32)
103+
104+
gdq[ints, group, :, :] |= flagarray
94105

95106
# for A/D floor, the flag is only set of the current plane
96-
np.bitwise_or(gdq[ints, group, :, :], flaglowarray, gdq[ints, group, :, :])
107+
flaglowarray = ((plane <= 0)*(ad_floor | dnu)).astype(np.uint32)
108+
109+
gdq[ints, group, :, :] |= flaglowarray
97110

98111
del flagarray
99112
del flaglowarray
100113

101114
# now, flag any pixels that border saturated pixels (not A/D floor pix)
102115
if n_pix_grow_sat > 0:
103-
gdq_slice = copy.copy(gdq[ints, group, :, :]).astype(int)
104-
105-
gdq[ints, group, :, :] = adjacent_pixels(gdq_slice, saturated, n_pix_grow_sat)
116+
gdq_slice = gdq[ints, group, :, :]
117+
adjacent_pixels(gdq_slice, saturated, n_pix_grow_sat, inplace=True)
106118

107119
# Work backward through the groups for a second pass at saturation
108120
# This is to flag things that actually saturated in prior groups but
109121
# were not obvious because of group averaging
110-
for group in range(ngroups-2, -1, -1):
122+
123+
for group in range(ngroups - 2, -1, -1):
124+
111125
plane = data[ints, group, :, :]
112126
thisdq = gdq[ints, group, :, :]
113127
nextdq = gdq[ints, group + 1, :, :]
114128

115129
# Determine the dilution factor due to group averaging
130+
131+
# No point in this step if the dilution factor is 1. In
132+
# that case, there is no way that we would have missed
133+
# saturation before but flag it now, since the threshold
134+
# would be the same.
135+
116136
if read_pattern is not None:
117137
# Single value dilution factor for this group
118138
dilution_factor = np.mean(read_pattern[group]) / read_pattern[group][-1]
139+
if dilution_factor == 1:
140+
continue
119141
# Broadcast to array size
120142
dilution_factor = np.where(no_sat_check_mask, 1, dilution_factor)
121143
else:
122144
dilution_factor = 1
145+
continue
123146

124-
# Find where this plane looks like it might saturate given the dilution factor
125-
flagarray, _ = plane_saturation(plane, sat_thresh * dilution_factor, dqflags)
147+
# Find where this plane looks like it might saturate given
148+
# the dilution factor, *and* this group did not already get
149+
# flagged as saturated or do not use, *and* the next group
150+
# was flagged as saturated. Result of the line below is a
151+
# boolean array.
126152

127-
# Find the overlap of where this plane looks like it might saturate, was not currently
128-
# flagged as saturation or DO_NOT_USE, and the next group had saturation flagged.
129-
indx = np.where((np.bitwise_and(flagarray, saturated) != 0) & \
130-
(np.bitwise_and(thisdq, saturated) == 0) & \
131-
(np.bitwise_and(thisdq, dnu) == 0) & \
132-
(np.bitwise_and(nextdq, saturated) != 0))
153+
partial_sat = ((plane >= sat_thresh*dilution_factor) & \
154+
(thisdq & (saturated | dnu) == 0) & \
155+
(nextdq & saturated != 0))
133156

134-
# Reset flag array to only pixels passing this gauntlet
135-
flagarray[:] = 0
136-
flagarray[indx] = dnu
157+
flagarray = (partial_sat * dnu).astype(np.uint32)
137158

138159
# Grow the newly-flagged saturating pixels
139160
if n_pix_grow_sat > 0:
140-
flagarray = adjacent_pixels(flagarray, dnu, n_pix_grow_sat)
161+
adjacent_pixels(flagarray, dnu, n_pix_grow_sat, inplace=True)
141162

142163
# Add them to the gdq array
143-
np.bitwise_or(gdq[ints, group, :, :], flagarray, gdq[ints, group, :, :])
164+
gdq[ints, group, :, :] |= flagarray
144165

145166
# Add an additional pass to look for things saturating in the second group
146167
# that can be particularly tricky to identify
@@ -160,25 +181,24 @@ def flag_saturated_pixels(
160181
mask &= scigp2 > sat_thresh / len(read_pattern[1])
161182

162183
# Identify groups that are saturated in the third group but not yet flagged in the second
163-
gp3mask = np.where((np.bitwise_and(dq3, saturated) != 0) & \
164-
(np.bitwise_and(dq2, saturated) == 0), True, False)
184+
gp3mask = ((np.bitwise_and(dq3, saturated) != 0) & \
185+
(np.bitwise_and(dq2, saturated) == 0))
165186
mask &= gp3mask
166187

167188
# Flag the 2nd group for the pixels passing that gauntlet
168-
flagarray = np.zeros_like(mask,dtype='uint8')
169-
flagarray[mask] = dnu
189+
flagarray = (mask * dnu).astype(np.uint32)
170190

171191
# Add them to the gdq array
172192
np.bitwise_or(gdq[ints, 1, :, :], flagarray, gdq[ints, 1, :, :])
173-
193+
174194

175195
# Check ZEROFRAME.
176196
if zframe is not None:
177197
plane = zframe[ints, :, :]
178198
flagarray, flaglowarray = plane_saturation(plane, sat_thresh, dqflags)
179199
zdq = flagarray | flaglowarray
180200
if n_pix_grow_sat > 0:
181-
zdq = adjacent_pixels(zdq, saturated, n_pix_grow_sat)
201+
adjacent_pixels(zdq, saturated, n_pix_grow_sat, inplace=True)
182202
plane[zdq != 0] = 0.0
183203
zframe[ints] = plane
184204

@@ -192,7 +212,7 @@ def flag_saturated_pixels(
192212
return gdq, pdq, zframe
193213

194214

195-
def adjacent_pixels(plane_gdq, saturated, n_pix_grow_sat):
215+
def adjacent_pixels(plane_gdq, saturated, n_pix_grow_sat=1, inplace=False):
196216
"""
197217
plane_gdq : ndarray
198218
The data quality flags of the current.
@@ -204,17 +224,65 @@ def adjacent_pixels(plane_gdq, saturated, n_pix_grow_sat):
204224
Number of pixels that each flagged saturated pixel should be 'grown',
205225
to account for charge spilling. Default is 1.
206226
227+
inplace : bool
228+
Update plane_gdq in place, returning None? Default False.
229+
207230
Return
208231
------
209232
sat_pix : ndarray
210233
The saturated pixels in the current plane.
211234
"""
212-
cgdq = plane_gdq.copy()
213-
only_sat = np.bitwise_and(plane_gdq, saturated).astype(np.uint8)
235+
if not inplace:
236+
cgdq = plane_gdq.copy()
237+
else:
238+
cgdq = plane_gdq
239+
240+
only_sat = plane_gdq & saturated > 0
241+
dilated = only_sat.copy()
214242
box_dim = (n_pix_grow_sat * 2) + 1
215-
struct = np.ones((box_dim, box_dim)).astype(bool)
216-
dialated = ndimage.binary_dilation(only_sat, structure=struct).astype(only_sat.dtype)
217-
return np.bitwise_or(cgdq, (dialated * saturated))
243+
244+
# The for loops below are equivalent to
245+
#
246+
#struct = np.ones((box_dim, box_dim)).astype(bool)
247+
#dilated = ndimage.binary_dilation(only_sat, structure=struct).astype(only_sat.dtype)
248+
#
249+
# The explicit loop over the box, followed by taking care of the
250+
# array edges, turns out to be faster by around an order of magnitude.
251+
# There must be poor coding in the underlying routine for
252+
# ndimage.binary_dilation as of scipy 1.14.1.
253+
254+
for i in range(box_dim):
255+
for j in range(box_dim):
256+
257+
# Explicit binary dilation over the inner ('valid')
258+
# region of the convolution/filter
259+
260+
i2 = only_sat.shape[0] - box_dim + i + 1
261+
j2 = only_sat.shape[1] - box_dim + j + 1
262+
263+
k1, k2, l1, l2 = [n_pix_grow_sat, -n_pix_grow_sat,
264+
n_pix_grow_sat, -n_pix_grow_sat]
265+
266+
dilated[k1:k2, l1:l2] |= only_sat[i:i2, j:j2]
267+
268+
for i in range(n_pix_grow_sat - 1, -1, -1):
269+
for j in range(i + n_pix_grow_sat, -1, -1):
270+
271+
# March from the limit of the 'valid' region toward
272+
# each edge. Maximum filter ensures correct dilation.
273+
274+
dilated[i] |= ndimage.maximum_filter(only_sat[j], box_dim)
275+
dilated[:, i] |= ndimage.maximum_filter(only_sat[:, j], box_dim)
276+
dilated[-i - 1] |= ndimage.maximum_filter(only_sat[-j - 1], box_dim)
277+
dilated[:, -i - 1] |= ndimage.maximum_filter(only_sat[:, -j - 1], box_dim)
278+
279+
cgdq[dilated] |= saturated
280+
281+
if inplace:
282+
return None
283+
else:
284+
return cgdq
285+
218286

219287

220288
def plane_saturation(plane, sat_thresh, dqflags):

0 commit comments

Comments
 (0)