Skip to content

Commit b65f560

Browse files
authored
JP-3669: Updating the C Extension to do CHARGELOSS Read Noise Recalculations (#275)
2 parents c23ac15 + d85f859 commit b65f560

10 files changed

+508
-74
lines changed

Diff for: .gitignore

+3
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@ __pycache__/
44
*$py.class
55
*~
66

7+
# Temp files
8+
*.*.swp
9+
710
# C extensions
811
*.so
912

Diff for: CHANGES.rst

+6
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,12 @@ General
2323

2424
- Add TweakReg submodule. [#267]
2525

26+
ramp_fitting
27+
~~~~~~~~~~~~
28+
29+
- Move the CHARGELOSS read noise variance recalculation from the JWST step
30+
code to the C extension to simplify the code and improve performance.[#275]
31+
2632
Changes to API
2733
--------------
2834

Diff for: changes/275.general.rst

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
[ramp_fitting] Moving the read noise recalculation due to CHARGELOSS flagging from
2+
the JWST ramp fit step code into the STCAL ramp fit C-extension.

Diff for: src/stcal/ramp_fitting/ols_fit.py

+7
Original file line numberDiff line numberDiff line change
@@ -920,6 +920,7 @@ def discard_miri_groups(ramp_data):
920920
data = ramp_data.data
921921
err = ramp_data.err
922922
groupdq = ramp_data.groupdq
923+
orig_gdq = ramp_data.orig_gdq
923924

924925
n_int, ngroups, nrows, ncols = data.shape
925926

@@ -949,6 +950,8 @@ def discard_miri_groups(ramp_data):
949950
if num_bad_slices > 0:
950951
data = data[:, num_bad_slices:, :, :]
951952
err = err[:, num_bad_slices:, :, :]
953+
if orig_gdq is not None:
954+
orig_gdq = orig_gdq[:, num_bad_slices:, :, :]
952955

953956
log.info("Number of leading groups that are flagged as DO_NOT_USE: %s", num_bad_slices)
954957

@@ -968,6 +971,8 @@ def discard_miri_groups(ramp_data):
968971
data = data[:, :-1, :, :]
969972
err = err[:, :-1, :, :]
970973
groupdq = groupdq[:, :-1, :, :]
974+
if orig_gdq is not None:
975+
orig_gdq = orig_gdq[:, :-1, :, :]
971976

972977
log.info("MIRI dataset has all pixels in the final group flagged as DO_NOT_USE.")
973978

@@ -981,6 +986,8 @@ def discard_miri_groups(ramp_data):
981986
ramp_data.data = data
982987
ramp_data.err = err
983988
ramp_data.groupdq = groupdq
989+
if orig_gdq is not None:
990+
ramp_data.orig_gdq = orig_gdq
984991

985992
return True
986993

Diff for: src/stcal/ramp_fitting/ramp_fit.py

+17-3
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
BUFSIZE = 1024 * 300000 # 300Mb cache size for data section
3131

3232

33-
def create_ramp_fit_class(model, dqflags=None, suppress_one_group=False):
33+
def create_ramp_fit_class(model, algorithm, dqflags=None, suppress_one_group=False):
3434
"""
3535
Create an internal ramp fit class from a data model.
3636
@@ -58,11 +58,24 @@ def create_ramp_fit_class(model, dqflags=None, suppress_one_group=False):
5858
else:
5959
dark_current_array = model.average_dark_current
6060

61+
orig_gdq = None
62+
if algorithm.upper() == "OLS_C":
63+
wh_chargeloss = np.where(np.bitwise_and(model.groupdq.astype(np.uint32), dqflags['CHARGELOSS']))
64+
if len(wh_chargeloss[0]) > 0:
65+
orig_gdq = model.groupdq.copy()
66+
del wh_chargeloss
67+
6168
if isinstance(model.data, u.Quantity):
6269
ramp_data.set_arrays(model.data.value, model.err.value, model.groupdq,
6370
model.pixeldq, dark_current_array)
6471
else:
65-
ramp_data.set_arrays(model.data, model.err, model.groupdq, model.pixeldq, dark_current_array)
72+
ramp_data.set_arrays(
73+
model.data,
74+
model.err,
75+
model.groupdq,
76+
model.pixeldq,
77+
dark_current_array,
78+
orig_gdq)
6679

6780
# Attribute may not be supported by all pipelines. Default is NoneType.
6881
drop_frames1 = model.meta.exposure.drop_frames1 if hasattr(model, "drop_frames1") else None
@@ -78,6 +91,7 @@ def create_ramp_fit_class(model, dqflags=None, suppress_one_group=False):
7891
if "zero_frame" in model.meta.exposure and model.meta.exposure.zero_frame:
7992
ramp_data.zeroframe = model.zeroframe
8093

94+
ramp_data.algorithm = algorithm
8195
ramp_data.set_dqflags(dqflags)
8296
ramp_data.start_row = 0
8397
ramp_data.num_rows = ramp_data.data.shape[2]
@@ -170,7 +184,7 @@ def ramp_fit(
170184
# Create an instance of the internal ramp class, using only values needed
171185
# for ramp fitting from the to remove further ramp fitting dependence on
172186
# data models.
173-
ramp_data = create_ramp_fit_class(model, dqflags, suppress_one_group)
187+
ramp_data = create_ramp_fit_class(model, algorithm, dqflags, suppress_one_group)
174188

175189
if algorithm.upper() == "OLS_C":
176190
ramp_data.run_c_code = True

Diff for: src/stcal/ramp_fitting/ramp_fit_class.py

+27-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@ def __init__(self):
1010
self.pixeldq = None
1111
self.average_dark_current = None
1212

13+
# Needed for CHARGELOSS recomputation
14+
self.orig_gdq = None
15+
self.algorithm = None
16+
1317
# Meta information
1418
self.instrument_name = None
1519

@@ -25,6 +29,7 @@ def __init__(self):
2529
self.flags_saturated = None
2630
self.flags_no_gain_val = None
2731
self.flags_unreliable_slope = None
32+
self.flags_chargeloss = None
2833

2934
# ZEROFRAME
3035
self.zframe_mat = None
@@ -41,13 +46,15 @@ def __init__(self):
4146

4247
# C code debugging switch.
4348
self.run_c_code = False
49+
self.run_chargeloss = True
50+
# self.run_chargeloss = False
4451

4552
self.one_groups_locs = None # One good group locations.
4653
self.one_groups_time = None # Time to use for one good group ramps.
4754

4855
self.current_integ = -1
4956

50-
def set_arrays(self, data, err, groupdq, pixeldq, average_dark_current):
57+
def set_arrays(self, data, err, groupdq, pixeldq, average_dark_current, orig_gdq=None):
5158
"""
5259
Set the arrays needed for ramp fitting.
5360
@@ -72,6 +79,11 @@ def set_arrays(self, data, err, groupdq, pixeldq, average_dark_current):
7279
average_dark_current : ndarray (float32)
7380
2-D array containing the average dark current. It has
7481
dimensions (nrows, ncols)
82+
83+
orig_gdq : ndarray
84+
4-D array containing a copy of the original group DQ array. Since
85+
the group DQ array can be modified during ramp fitting, this keeps
86+
around the original group DQ flags passed to ramp fitting.
7587
"""
7688
# Get arrays from the data model
7789
self.data = data
@@ -80,6 +92,8 @@ def set_arrays(self, data, err, groupdq, pixeldq, average_dark_current):
8092
self.pixeldq = pixeldq
8193
self.average_dark_current = average_dark_current
8294

95+
self.orig_gdq = orig_gdq
96+
8397
def set_meta(self, name, frame_time, group_time, groupgap, nframes, drop_frames1=None):
8498
"""
8599
Set the metainformation needed for ramp fitting.
@@ -131,6 +145,8 @@ def set_dqflags(self, dqflags):
131145
self.flags_saturated = dqflags["SATURATED"]
132146
self.flags_no_gain_val = dqflags["NO_GAIN_VALUE"]
133147
self.flags_unreliable_slope = dqflags["UNRELIABLE_SLOPE"]
148+
if self.algorithm is not None and self.algorithm.upper() == "OLS_C":
149+
self.flags_chargeloss = dqflags["CHARGELOSS"]
134150

135151
def dbg_print_types(self):
136152
# Arrays from the data model
@@ -200,6 +216,16 @@ def dbg_print_pixel_info(self, row, col):
200216
# print(f" err :\n{self.err[:, :, row, col]}")
201217
# print(f" pixeldq :\n{self.pixeldq[row, col]}")
202218

219+
def dbg_print_info(self):
220+
print(" ")
221+
nints, ngroups, nrows, ncols = self.data.shape
222+
for row in range(nrows):
223+
for col in range(ncols):
224+
print("=" * 80)
225+
print(f"**** Pixel ({row}, {col}) ****")
226+
self.dbg_print_pixel_info(row, col)
227+
print("=" * 80)
228+
203229
def dbg_write_ramp_data_pix_pre(self, fname, row, col, fd):
204230
fd.write("def create_ramp_data_pixel():\n")
205231
indent = INDENT

0 commit comments

Comments
 (0)