Commit cffbd85

Merge pull request #31 from mpaillassa/cube_inputs
Cube HDU input support
2 parents: 838d9e6 + 5cad7f4

3 files changed: 152 additions & 60 deletions
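Taken together, the three diffs below add cube support end to end: utils.check_hdu now accepts 3D data shapes, utils.image_norm normalizes each channel separately, and both MaxiMask and MaxiTrack loop over the channel images at prediction time. As a minimal sketch of the kind of input this enables, assuming astropy.io.fits (the file name and cube dimensions are illustrative, not from the commit):

import numpy as np
from astropy.io import fits

# a 4-channel cube HDU: data shape (channels, height, width)
cube = np.random.normal(1000.0, 10.0, size=(4, 512, 512)).astype(np.float32)
fits.PrimaryHDU(data=cube).writeto("cube_example.fits", overwrite=True)

After this change, each of the 4 channel images in such a file is detected, normalized, and predicted on independently.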

maximask_and_maxitrack/maximask/maximask.py

Lines changed: 73 additions & 30 deletions
@@ -290,49 +290,92 @@ def process_hdu(self, file_name, hdu_idx, task, tf_model):
         elif task == "process":

             # prediction array
-            h, w = hdu_data.shape
+            hdu_shape = hdu_data.shape
             if np.all(hdu_data == 0):
                 return np.zeros_like(hdu_data, dtype=np.uint8)
             else:
                 if self.sing_mask:
-                    preds = np.zeros([h, w], dtype=np.int16)
+                    preds = np.zeros_like(hdu_data, dtype=np.int16)
                 elif self.thresholds is not None:
-                    preds = np.zeros([h, w, np.sum(self.class_flags)], dtype=np.uint8)
+                    preds = np.zeros(
+                        list(hdu_shape) + [np.sum(self.class_flags)], dtype=np.uint8
+                    )
                 else:
-                    preds = np.zeros([h, w, np.sum(self.class_flags)], dtype=np.float32)
-
-                # get list of block coordinates to process
-                block_coord_list = self.get_block_coords(h, w)
+                    preds = np.zeros(
+                        list(hdu_shape) + [np.sum(self.class_flags)], dtype=np.float32
+                    )

                 # preprocessing
                 log.info("Preprocessing...")
                 hdu_data, t = utils.image_norm(hdu_data)
-                log.info(f"Preprocessing done in {t:.2f}s, {h*w/(t*1e06):.2f}MPix/s")
-
-                # process all the blocks by batches
-                # the process_batch method writes the predictions in preds by reference
-                nb_blocks = len(block_coord_list)
-                if nb_blocks <= self.batch_size:
-                    # only one (possibly not full) batch to process
-                    self.process_batch(hdu_data, preds, tf_model, block_coord_list)
-                else:
-                    # several batches to process + one last possibly not full
-                    nb_batch = nb_blocks // self.batch_size
-                    for b in tqdm.tqdm(range(nb_batch), desc="INFERENCE: "):
-                        batch_coord_list = block_coord_list[
-                            b * self.batch_size : (b + 1) * self.batch_size
-                        ]
-                        self.process_batch(hdu_data, preds, tf_model, batch_coord_list)
-                    rest = nb_blocks - nb_batch * self.batch_size
-                    if rest:
-                        batch_coord_list = block_coord_list[-rest:]
-                        self.process_batch(hdu_data, preds, tf_model, batch_coord_list)
-
-                if not self.sing_mask:
-                    preds = np.transpose(preds, (2, 0, 1))
+                log.info(
+                    f"Preprocessing done in {t:.2f}s, {np.prod(hdu_shape)/(t*1e06):.2f}MPix/s"
+                )
+
+                # process the HDU 3D or 2D data
+                if len(hdu_shape) == 3:
+                    c, h, w = hdu_shape
+                    for ch in tqdm.tqdm(range(c), desc="CUBE CHANNELS"):
+
+                        # make temporary 2D prediction array to get results by reference
+                        if self.sing_mask:
+                            tmp_preds = np.zeros([h, w], dtype=np.int16)
+                        elif self.thresholds is not None:
+                            tmp_preds = np.zeros(
+                                [h, w, np.sum(self.class_flags)], dtype=np.uint8
+                            )
+                        else:
+                            tmp_preds = np.zeros(
+                                [h, w, np.sum(self.class_flags)], dtype=np.float32
+                            )
+
+                        # make predictions and forward them to the final prediction array
+                        ch_im_data = hdu_data[ch]
+                        self.process_image(ch_im_data, tmp_preds, tf_model)
+                        preds[ch] = tmp_preds
+
+                elif len(hdu_shape) == 2:
+                    self.process_image(hdu_data, preds, tf_model)

         return preds

+    def process_image(self, im_data, preds, tf_model):
+        """Processes 2D image data.
+
+        Args:
+            im_data (np.ndarray): 2D image data to process.
+            preds (np.ndarray): corresponding 2D MaxiMask predictions to fill.
+            tf_model (tf.keras.Model): MaxiMask tensorflow model.
+        """
+
+        # get list of block coordinates to process
+        h, w = im_data.shape
+        block_coord_list = self.get_block_coords(h, w)
+
+        # process all the blocks by batches
+        # the process_batch method writes the predictions in preds by reference
+        nb_blocks = len(block_coord_list)
+        if nb_blocks <= self.batch_size:
+            # only one (possibly not full) batch to process
+            self.process_batch(im_data, preds, tf_model, block_coord_list)
+        else:
+            # several batches to process + one last possibly not full
+            nb_batch = nb_blocks // self.batch_size
+            for b in tqdm.tqdm(range(nb_batch), desc="INFERENCE: "):
+                batch_coord_list = block_coord_list[
+                    b * self.batch_size : (b + 1) * self.batch_size
+                ]
+                self.process_batch(im_data, preds, tf_model, batch_coord_list)
+            rest = nb_blocks - nb_batch * self.batch_size
+            if rest:
+                batch_coord_list = block_coord_list[-rest:]
+                self.process_batch(im_data, preds, tf_model, batch_coord_list)
+
+        if not self.sing_mask:
+            preds = np.transpose(preds, (2, 0, 1))
+
+        return preds
+
     def get_block_coords(self, h, w):
         """Gets the coordinate list of blocks to process.

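The refactor above splits the old monolithic loop in two: process_hdu now only dispatches on the dimensionality of the HDU data, and the block-batching inference moves into the new process_image, which fills its preds argument by reference. A self-contained sketch of that dispatch pattern, where process_any and process_2d are hypothetical stand-ins for process_hdu and process_image, and the thresholding is a dummy placeholder for the real inference:

import numpy as np

def process_2d(im_data, preds):
    # stand-in for block-wise inference that writes into preds by reference
    preds[...] = im_data > im_data.mean()

def process_any(hdu_data):
    hdu_shape = hdu_data.shape
    preds = np.zeros(hdu_shape, dtype=np.uint8)
    if len(hdu_shape) == 3:
        # cube: predict each channel into a temporary 2D array
        c, h, w = hdu_shape
        for ch in range(c):
            tmp_preds = np.zeros([h, w], dtype=np.uint8)
            process_2d(hdu_data[ch], tmp_preds)
            preds[ch] = tmp_preds
    elif len(hdu_shape) == 2:
        process_2d(hdu_data, preds)
    return preds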

maximask_and_maxitrack/maxitrack/maxitrack.py

Lines changed: 60 additions & 25 deletions
@@ -81,8 +81,6 @@ def process_file(self, file_name, tf_model):
             tf_model (tf.keras.Model): MaxiTrack tensorflow model.
         """

-        all_preds = []
-
         # make hdu tasks
         hdu_task_list = self.make_hdu_tasks(file_name)

@@ -102,25 +100,34 @@ def process_file(self, file_name, tf_model):
             else:
                 log.info("Using all available HDUs")

-            # go through all HDUs
-            for hdu_idx, hdu_type, hdu_shape in hdu_task_list:
-                log.info(f"HDU {hdu_idx}")
+            # go through all HDUs and write results
+            all_2d_hdu_preds = []
+            with open("maxitrack.out", "a") as fd:

-                # get raw predictions
-                preds, t = self.process_hdu(file_name, hdu_idx, tf_model)
-                log.info(
-                    f"Whole processing time (incl. preprocessing): {t:.2f}s, {np.prod(hdu_shape)/(t*1e06):.2f}MPix/s"
-                )
+                for hdu_idx, hdu_shape in hdu_task_list:
+                    log.info(f"HDU {hdu_idx}")

-                # append the results
-                for pred in preds:
-                    all_preds.append(pred)
+                    # get raw predictions
+                    preds, t = self.process_hdu(file_name, hdu_idx, tf_model)
+                    log.info(
+                        f"Whole processing time (incl. preprocessing): {t:.2f}s, {np.prod(hdu_shape)/(t*1e06):.2f}MPix/s"
+                    )

-            final_res = np.mean(all_preds)
+                    # if this is a 3D HDU, output a score per channel image
+                    if len(preds) > 1:
+                        for ch in range(len(preds)):
+                            fd.write(
+                                f"{file_name} HDU {hdu_idx} Channel {ch} {preds[ch]:.4f}\n"
+                            )
+                    # if this is a 2D HDU, assume it is the same field over all 2D HDUs and aggregate a score over them
+                    elif len(preds) == 1:
+                        all_2d_hdu_preds.append(preds)
+
+                # write the aggregated score of 2D HDUs
+                if len(all_2d_hdu_preds):
+                    final_pred = np.mean(all_2d_hdu_preds)
+                    fd.write(f"{file_name} {final_pred:.4f}\n")

-            # write file
-            with open("maxitrack.out", "a") as fd:
-                fd.write(f"{file_name} {final_res:.4f}\n")
         else:
             log.info(f"Skipping {file_name} because no HDU was found to be processed")

@@ -145,7 +152,7 @@ def make_hdu_tasks(self, file_name):
                 check, hdu_type = utils.check_hdu(specified_hdu, self.im_size)
                 if check:
                     hdu_shape = specified_hdu.data.shape
-                    hdu_task_list.append([spec_hdu_idx, hdu_type, hdu_shape])
+                    hdu_task_list.append([spec_hdu_idx, hdu_shape])
                 else:
                     log.info(
                         f"Ignoring HDU {spec_hdu_idx} because not adequate data format"
@@ -157,7 +164,7 @@ def make_hdu_tasks(self, file_name):
                 check, hdu_type = utils.check_hdu(file_hdu[k], self.im_size)
                 if check:
                     hdu_shape = file_hdu[k].data.shape
-                    hdu_task_list.append([k, hdu_type, hdu_shape])
+                    hdu_task_list.append([k, hdu_shape])
                 else:
                     log.info(f"Ignoring HDU {k} because not adequate data format")

@@ -172,7 +179,7 @@ def process_hdu(self, file_name, hdu_idx, tf_model):
             hdu_idx (int): index of the HDU to process.
             tf_model (tf.keras.Model): MaxiTrack tensorflow model.
         Returns:
-            out_array (np.ndarray): MaxiTrack predictions over the image.
+            outputs (list): MaxiTrack predictions over the image, one per channel for 3D data.
         """

         # make file name
@@ -184,15 +191,43 @@ def process_hdu(self, file_name, hdu_idx, tf_model):
         with fits.open(file_name) as file_hdu:
             hdu = file_hdu[hdu_idx]
             im_data = hdu.data
-
-        # get list of block coordinates to process
-        h, w = im_data.shape
-        block_coord_list = self.get_block_coords(h, w)
+            im_data_shape = im_data.shape

         # preprocessing
         log.info("Preprocessing...")
         im_data, t = utils.image_norm(im_data)
-        log.info(f"Preprocessing done in {t:.2f}s, {h*w/(t*1e06):.2f}MPix/s")
+        log.info(
+            f"Preprocessing done in {t:.2f}s, {np.prod(im_data_shape)/(t*1e06):.2f}MPix/s"
+        )
+
+        # process the HDU 3D or 2D data
+        outputs = []
+        if len(im_data_shape) == 3:
+            c = im_data.shape[0]
+            for ch in tqdm.tqdm(range(c), desc="CUBE CHANNELS"):
+                ch_im_data = im_data[ch]
+                predictions = self.process_image(ch_im_data, tf_model)
+                outputs.append(np.mean(predictions))
+
+        elif len(im_data_shape) == 2:
+            predictions = self.process_image(im_data, tf_model)
+            outputs.append(np.mean(predictions))
+
+        return outputs
+
+    def process_image(self, im_data, tf_model):
+        """Processes 2D image data.
+
+        Args:
+            im_data (np.ndarray): 2D image data to process.
+            tf_model (tf.keras.Model): MaxiTrack tensorflow model.
+        Returns:
+            out_array (np.ndarray): MaxiTrack predictions over the image.
+        """
+
+        # get list of block coordinates to process
+        h, w = im_data.shape
+        block_coord_list = self.get_block_coords(h, w)

         # process all the blocks by batches
         nb_blocks = len(block_coord_list)
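The rewritten output logic changes the maxitrack.out convention: a cube HDU now gets one line per channel, while all 2D HDUs of a file are still averaged into a single aggregated line. A short sketch of the resulting format, reusing the fd.write templates from the diff with hypothetical file names and scores:

with open("maxitrack.out", "a") as fd:
    # per-channel lines for a 3D HDU (scores are made up for illustration)
    for ch, p in enumerate([0.0123, 0.0456]):
        fd.write(f"example_cube.fits HDU 0 Channel {ch} {p:.4f}\n")
    # single aggregated line for a file containing only 2D HDUs
    fd.write(f"example_2d.fits {0.0789:.4f}\n")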

maximask_and_maxitrack/utils.py

Lines changed: 19 additions & 5 deletions
@@ -149,9 +149,16 @@ def check_hdu(hdu, min_size):
         (bool): whether the hdu is to be processed or not.
     """

+    # get HDU information
     infos = hdu._summary()
+
+    # check size validity
     ds = infos[4]
-    size_b = len(ds) == 2 and ds[0] > min_size and ds[1] > min_size
+    size_b_2d = len(ds) == 2 and ds[0] > min_size and ds[1] > min_size
+    size_b_3d = len(ds) == 3 and ds[1] > min_size and ds[2] > min_size
+    size_b = size_b_2d or size_b_3d
+
+    # check data type validity
     dt = infos[5]
     data_type_b = (
         "float16" in dt
@@ -347,9 +354,16 @@ def image_norm(im):
     np.place(im, im > 80000, 80000)
     np.place(im, im < -500, -500)

-    # normalization
-    bg_map, si_map = background_est(im)
-    im -= bg_map
-    im /= si_map
+    # normalize single image or all channels if 3D
+    im_shape = im.shape
+    if len(im_shape) == 3:
+        for ch in range(im_shape[0]):
+            bg_map, si_map = background_est(im[ch])
+            im[ch] -= bg_map
+            im[ch] /= si_map
+    elif len(im_shape) == 2:
+        bg_map, si_map = background_est(im)
+        im -= bg_map
+        im /= si_map

     return im
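image_norm now branches on dimensionality so that the background and scatter maps are estimated per channel rather than once over the whole cube. A runnable sketch of that behavior, with a hypothetical constant-background estimator standing in for the repository's background_est:

import numpy as np

def fake_background_est(im):
    # hypothetical stand-in: constant background and scatter maps
    return np.full_like(im, im.mean()), np.full_like(im, im.std())

def norm(im):
    if im.ndim == 3:
        # normalize each channel of the cube independently
        for ch in range(im.shape[0]):
            bg_map, si_map = fake_background_est(im[ch])
            im[ch] -= bg_map
            im[ch] /= si_map
    elif im.ndim == 2:
        bg_map, si_map = fake_background_est(im)
        im -= bg_map
        im /= si_map
    return im

cube = np.random.normal(100.0, 5.0, size=(3, 64, 64))
norm(cube)  # each channel ends up with ~zero mean and ~unit scatter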
