From 63c4c8f56929171140bba9a481900f97917e9cdf Mon Sep 17 00:00:00 2001 From: mpaillassa Date: Fri, 22 Mar 2024 15:50:55 +0900 Subject: [PATCH] cube single_mask fix --- maximask_and_maxitrack/maximask/maximask.py | 46 +++++++++++---------- 1 file changed, 24 insertions(+), 22 deletions(-) diff --git a/maximask_and_maxitrack/maximask/maximask.py b/maximask_and_maxitrack/maximask/maximask.py index 7487777..cf2253c 100644 --- a/maximask_and_maxitrack/maximask/maximask.py +++ b/maximask_and_maxitrack/maximask/maximask.py @@ -102,6 +102,7 @@ def __init__( if self.class_flags[k]: class_idx.append(k) self.class_idx = np.array(class_idx, dtype=np.int32) + self.effective_class_nb = np.sum(class_flags) # builds the prior factor values to be given for inference self.priors = priors @@ -298,11 +299,11 @@ def process_hdu(self, file_name, hdu_idx, task, tf_model): preds = np.zeros_like(hdu_data, dtype=np.int16) elif self.thresholds is not None: preds = np.zeros( - list(hdu_shape) + [np.sum(self.class_flags)], dtype=np.uint8 + list(hdu_shape) + [self.effective_class_nb], dtype=np.uint8 ) else: preds = np.zeros( - list(hdu_shape) + [np.sum(self.class_flags)], dtype=np.float32 + list(hdu_shape) + [self.effective_class_nb], dtype=np.float32 ) # preprocessing @@ -315,29 +316,35 @@ def process_hdu(self, file_name, hdu_idx, task, tf_model): # process the HDU 3D or 2D data if len(hdu_shape) == 3: c, h, w = hdu_shape - for ch in tqdm.tqdm(range(c), desc="CUBE CHANNELS"): - # make temporary 2D prediction array to get results by reference - if self.sing_mask: - tmp_preds = np.zeros_like([h, w], dtype=np.int16) - elif self.thresholds is not None: - tmp_preds = np.zeros( - [h, w, np.sum(self.class_flags)], dtype=np.uint8 - ) - else: - tmp_preds = np.zeros( - [h, w, np.sum(self.class_flags)], dtype=np.float32 - ) - - # make predictions and forward them to the final prediction array + # make temporary 2D prediction array to get results by reference for each channel individually + if self.sing_mask: + tmp_preds = np.zeros([h, w], dtype=np.int16) + elif self.thresholds is not None: + tmp_preds = np.zeros( + [h, w, self.effective_class_nb], dtype=np.uint8 + ) + else: + tmp_preds = np.zeros( + [h, w, self.effective_class_nb], dtype=np.float32 + ) + + # make predictions and forward them to the final prediction array + for ch in tqdm.tqdm(range(c), desc="CUBE CHANNELS"): ch_im_data = hdu_data[ch] self.process_image(ch_im_data, tmp_preds, tf_model) preds[ch] = tmp_preds + if not self.sing_mask: + preds = np.transpose(preds, (0, 3, 1, 2)) + elif len(hdu_shape) == 2: self.process_image(hdu_data, preds, tf_model) - return preds + if not self.sing_mask: + preds = np.transpose(preds, (2, 0, 1)) + + return preds def process_image(self, im_data, preds, tf_model): """Processes 2D image data. @@ -371,11 +378,6 @@ def process_image(self, im_data, preds, tf_model): batch_coord_list = block_coord_list[-rest:] self.process_batch(im_data, preds, tf_model, batch_coord_list) - if not self.sing_mask: - preds = np.transpose(preds, (2, 0, 1)) - - return preds - def get_block_coords(self, h, w): """Gets the coordinate list of blocks to process.