Skip to content

Commit

Permalink
Saturation speedup (#331)
Browse files Browse the repository at this point in the history
  • Loading branch information
melanieclarke authored Feb 12, 2025
2 parents 651015d + 5cecab4 commit b35bccc
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 33 deletions.
1 change: 1 addition & 0 deletions changes/331.general.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Performance improvements for saturation step
134 changes: 101 additions & 33 deletions src/stcal/saturation/saturation.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,64 +83,85 @@ def flag_saturated_pixels(

for ints in range(nints):
# Work forward through the groups for initial pass at saturation

# We want to flag saturation in all subsequent groups after
# the one in which it was found. Use this boolean array to
# keep a running tally of pixels that have saturated.
previously_saturated = np.zeros(shape=(nrows, ncols), dtype='bool')

for group in range(ngroups):
plane = data[ints, group, :, :]

flagarray, flaglowarray = plane_saturation(plane, sat_thresh, dqflags)

# for saturation, the flag is set in the current plane
# and all following planes.
np.bitwise_or(gdq[ints, group:, :, :], flagarray, gdq[ints, group:, :, :])

# Update the running tally of all pixels that have ever
# experienced saturation to account for this.

previously_saturated |= (plane >= sat_thresh)
flagarray = (previously_saturated * saturated).astype(np.uint32)

gdq[ints, group, :, :] |= flagarray

# for A/D floor, the flag is only set of the current plane
np.bitwise_or(gdq[ints, group, :, :], flaglowarray, gdq[ints, group, :, :])
flaglowarray = ((plane <= 0)*(ad_floor | dnu)).astype(np.uint32)

gdq[ints, group, :, :] |= flaglowarray

del flagarray
del flaglowarray

# now, flag any pixels that border saturated pixels (not A/D floor pix)
if n_pix_grow_sat > 0:
gdq_slice = copy.copy(gdq[ints, group, :, :]).astype(int)

gdq[ints, group, :, :] = adjacent_pixels(gdq_slice, saturated, n_pix_grow_sat)
gdq_slice = gdq[ints, group, :, :]
adjacent_pixels(gdq_slice, saturated, n_pix_grow_sat, inplace=True)

# Work backward through the groups for a second pass at saturation
# This is to flag things that actually saturated in prior groups but
# were not obvious because of group averaging
for group in range(ngroups-2, -1, -1):

for group in range(ngroups - 2, -1, -1):

plane = data[ints, group, :, :]
thisdq = gdq[ints, group, :, :]
nextdq = gdq[ints, group + 1, :, :]

# Determine the dilution factor due to group averaging

# No point in this step if the dilution factor is 1. In
# that case, there is no way that we would have missed
# saturation before but flag it now, since the threshold
# would be the same.

if read_pattern is not None:
# Single value dilution factor for this group
dilution_factor = np.mean(read_pattern[group]) / read_pattern[group][-1]
if dilution_factor == 1:
continue
# Broadcast to array size
dilution_factor = np.where(no_sat_check_mask, 1, dilution_factor)
else:
dilution_factor = 1
continue

# Find where this plane looks like it might saturate given the dilution factor
flagarray, _ = plane_saturation(plane, sat_thresh * dilution_factor, dqflags)
# Find where this plane looks like it might saturate given
# the dilution factor, *and* this group did not already get
# flagged as saturated or do not use, *and* the next group
# was flagged as saturated. Result of the line below is a
# boolean array.

# Find the overlap of where this plane looks like it might saturate, was not currently
# flagged as saturation or DO_NOT_USE, and the next group had saturation flagged.
indx = np.where((np.bitwise_and(flagarray, saturated) != 0) & \
(np.bitwise_and(thisdq, saturated) == 0) & \
(np.bitwise_and(thisdq, dnu) == 0) & \
(np.bitwise_and(nextdq, saturated) != 0))
partial_sat = ((plane >= sat_thresh*dilution_factor) & \
(thisdq & (saturated | dnu) == 0) & \
(nextdq & saturated != 0))

# Reset flag array to only pixels passing this gauntlet
flagarray[:] = 0
flagarray[indx] = dnu
flagarray = (partial_sat * dnu).astype(np.uint32)

# Grow the newly-flagged saturating pixels
if n_pix_grow_sat > 0:
flagarray = adjacent_pixels(flagarray, dnu, n_pix_grow_sat)
adjacent_pixels(flagarray, dnu, n_pix_grow_sat, inplace=True)

# Add them to the gdq array
np.bitwise_or(gdq[ints, group, :, :], flagarray, gdq[ints, group, :, :])
gdq[ints, group, :, :] |= flagarray

# Add an additional pass to look for things saturating in the second group
# that can be particularly tricky to identify
Expand All @@ -160,25 +181,24 @@ def flag_saturated_pixels(
mask &= scigp2 > sat_thresh / len(read_pattern[1])

# Identify groups that are saturated in the third group but not yet flagged in the second
gp3mask = np.where((np.bitwise_and(dq3, saturated) != 0) & \
(np.bitwise_and(dq2, saturated) == 0), True, False)
gp3mask = ((np.bitwise_and(dq3, saturated) != 0) & \
(np.bitwise_and(dq2, saturated) == 0))
mask &= gp3mask

# Flag the 2nd group for the pixels passing that gauntlet
flagarray = np.zeros_like(mask,dtype='uint8')
flagarray[mask] = dnu
flagarray = (mask * dnu).astype(np.uint32)

# Add them to the gdq array
np.bitwise_or(gdq[ints, 1, :, :], flagarray, gdq[ints, 1, :, :])


# Check ZEROFRAME.
if zframe is not None:
plane = zframe[ints, :, :]
flagarray, flaglowarray = plane_saturation(plane, sat_thresh, dqflags)
zdq = flagarray | flaglowarray
if n_pix_grow_sat > 0:
zdq = adjacent_pixels(zdq, saturated, n_pix_grow_sat)
adjacent_pixels(zdq, saturated, n_pix_grow_sat, inplace=True)
plane[zdq != 0] = 0.0
zframe[ints] = plane

Expand All @@ -192,7 +212,7 @@ def flag_saturated_pixels(
return gdq, pdq, zframe


def adjacent_pixels(plane_gdq, saturated, n_pix_grow_sat):
def adjacent_pixels(plane_gdq, saturated, n_pix_grow_sat=1, inplace=False):
"""
plane_gdq : ndarray
The data quality flags of the current.
Expand All @@ -204,17 +224,65 @@ def adjacent_pixels(plane_gdq, saturated, n_pix_grow_sat):
Number of pixels that each flagged saturated pixel should be 'grown',
to account for charge spilling. Default is 1.
inplace : bool
Update plane_gdq in place, returning None? Default False.
Return
------
sat_pix : ndarray
The saturated pixels in the current plane.
"""
cgdq = plane_gdq.copy()
only_sat = np.bitwise_and(plane_gdq, saturated).astype(np.uint8)
if not inplace:
cgdq = plane_gdq.copy()
else:
cgdq = plane_gdq

only_sat = plane_gdq & saturated > 0
dilated = only_sat.copy()
box_dim = (n_pix_grow_sat * 2) + 1
struct = np.ones((box_dim, box_dim)).astype(bool)
dialated = ndimage.binary_dilation(only_sat, structure=struct).astype(only_sat.dtype)
return np.bitwise_or(cgdq, (dialated * saturated))

# The for loops below are equivalent to
#
#struct = np.ones((box_dim, box_dim)).astype(bool)
#dilated = ndimage.binary_dilation(only_sat, structure=struct).astype(only_sat.dtype)
#
# The explicit loop over the box, followed by taking care of the
# array edges, turns out to be faster by around an order of magnitude.
# There must be poor coding in the underlying routine for
# ndimage.binary_dilation as of scipy 1.14.1.

for i in range(box_dim):
for j in range(box_dim):

# Explicit binary dilation over the inner ('valid')
# region of the convolution/filter

i2 = only_sat.shape[0] - box_dim + i + 1
j2 = only_sat.shape[1] - box_dim + j + 1

k1, k2, l1, l2 = [n_pix_grow_sat, -n_pix_grow_sat,
n_pix_grow_sat, -n_pix_grow_sat]

dilated[k1:k2, l1:l2] |= only_sat[i:i2, j:j2]

for i in range(n_pix_grow_sat - 1, -1, -1):
for j in range(i + n_pix_grow_sat, -1, -1):

# March from the limit of the 'valid' region toward
# each edge. Maximum filter ensures correct dilation.

dilated[i] |= ndimage.maximum_filter(only_sat[j], box_dim)
dilated[:, i] |= ndimage.maximum_filter(only_sat[:, j], box_dim)
dilated[-i - 1] |= ndimage.maximum_filter(only_sat[-j - 1], box_dim)
dilated[:, -i - 1] |= ndimage.maximum_filter(only_sat[:, -j - 1], box_dim)

cgdq[dilated] |= saturated

if inplace:
return None
else:
return cgdq



def plane_saturation(plane, sat_thresh, dqflags):
Expand Down

0 comments on commit b35bccc

Please sign in to comment.