Skip to content

Commit 63c4c8f

Browse files
committed
cube single_mask fix
1 parent cffbd85 commit 63c4c8f

File tree

1 file changed

+24
-22
lines changed

1 file changed

+24
-22
lines changed

maximask_and_maxitrack/maximask/maximask.py

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ def __init__(
102102
if self.class_flags[k]:
103103
class_idx.append(k)
104104
self.class_idx = np.array(class_idx, dtype=np.int32)
105+
self.effective_class_nb = np.sum(class_flags)
105106

106107
# builds the prior factor values to be given for inference
107108
self.priors = priors
@@ -298,11 +299,11 @@ def process_hdu(self, file_name, hdu_idx, task, tf_model):
298299
preds = np.zeros_like(hdu_data, dtype=np.int16)
299300
elif self.thresholds is not None:
300301
preds = np.zeros(
301-
list(hdu_shape) + [np.sum(self.class_flags)], dtype=np.uint8
302+
list(hdu_shape) + [self.effective_class_nb], dtype=np.uint8
302303
)
303304
else:
304305
preds = np.zeros(
305-
list(hdu_shape) + [np.sum(self.class_flags)], dtype=np.float32
306+
list(hdu_shape) + [self.effective_class_nb], dtype=np.float32
306307
)
307308

308309
# preprocessing
@@ -315,29 +316,35 @@ def process_hdu(self, file_name, hdu_idx, task, tf_model):
315316
# process the HDU 3D or 2D data
316317
if len(hdu_shape) == 3:
317318
c, h, w = hdu_shape
318-
for ch in tqdm.tqdm(range(c), desc="CUBE CHANNELS"):
319319

320-
# make temporary 2D prediction array to get results by reference
321-
if self.sing_mask:
322-
tmp_preds = np.zeros_like([h, w], dtype=np.int16)
323-
elif self.thresholds is not None:
324-
tmp_preds = np.zeros(
325-
[h, w, np.sum(self.class_flags)], dtype=np.uint8
326-
)
327-
else:
328-
tmp_preds = np.zeros(
329-
[h, w, np.sum(self.class_flags)], dtype=np.float32
330-
)
331-
332-
# make predictions and forward them to the final prediction array
320+
# make temporary 2D prediction array to get results by reference for each channel individually
321+
if self.sing_mask:
322+
tmp_preds = np.zeros([h, w], dtype=np.int16)
323+
elif self.thresholds is not None:
324+
tmp_preds = np.zeros(
325+
[h, w, self.effective_class_nb], dtype=np.uint8
326+
)
327+
else:
328+
tmp_preds = np.zeros(
329+
[h, w, self.effective_class_nb], dtype=np.float32
330+
)
331+
332+
# make predictions and forward them to the final prediction array
333+
for ch in tqdm.tqdm(range(c), desc="CUBE CHANNELS"):
333334
ch_im_data = hdu_data[ch]
334335
self.process_image(ch_im_data, tmp_preds, tf_model)
335336
preds[ch] = tmp_preds
336337

338+
if not self.sing_mask:
339+
preds = np.transpose(preds, (0, 3, 1, 2))
340+
337341
elif len(hdu_shape) == 2:
338342
self.process_image(hdu_data, preds, tf_model)
339343

340-
return preds
344+
if not self.sing_mask:
345+
preds = np.transpose(preds, (2, 0, 1))
346+
347+
return preds
341348

342349
def process_image(self, im_data, preds, tf_model):
343350
"""Processes 2D image data.
@@ -371,11 +378,6 @@ def process_image(self, im_data, preds, tf_model):
371378
batch_coord_list = block_coord_list[-rest:]
372379
self.process_batch(im_data, preds, tf_model, batch_coord_list)
373380

374-
if not self.sing_mask:
375-
preds = np.transpose(preds, (2, 0, 1))
376-
377-
return preds
378-
379381
def get_block_coords(self, h, w):
380382
"""Gets the coordinate list of blocks to process.
381383

0 commit comments

Comments
 (0)