Skip to content

Commit

Permalink
cube single_mask fix
Browse files Browse the repository at this point in the history
  • Loading branch information
mpaillassa committed Mar 22, 2024
1 parent cffbd85 commit 63c4c8f
Showing 1 changed file with 24 additions and 22 deletions.
46 changes: 24 additions & 22 deletions maximask_and_maxitrack/maximask/maximask.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def __init__(
if self.class_flags[k]:
class_idx.append(k)
self.class_idx = np.array(class_idx, dtype=np.int32)
self.effective_class_nb = np.sum(class_flags)

# builds the prior factor values to be given for inference
self.priors = priors
Expand Down Expand Up @@ -298,11 +299,11 @@ def process_hdu(self, file_name, hdu_idx, task, tf_model):
preds = np.zeros_like(hdu_data, dtype=np.int16)
elif self.thresholds is not None:
preds = np.zeros(
list(hdu_shape) + [np.sum(self.class_flags)], dtype=np.uint8
list(hdu_shape) + [self.effective_class_nb], dtype=np.uint8
)
else:
preds = np.zeros(
list(hdu_shape) + [np.sum(self.class_flags)], dtype=np.float32
list(hdu_shape) + [self.effective_class_nb], dtype=np.float32
)

# preprocessing
Expand All @@ -315,29 +316,35 @@ def process_hdu(self, file_name, hdu_idx, task, tf_model):
# process the HDU 3D or 2D data
if len(hdu_shape) == 3:
c, h, w = hdu_shape
for ch in tqdm.tqdm(range(c), desc="CUBE CHANNELS"):

# make temporary 2D prediction array to get results by reference
if self.sing_mask:
tmp_preds = np.zeros_like([h, w], dtype=np.int16)
elif self.thresholds is not None:
tmp_preds = np.zeros(
[h, w, np.sum(self.class_flags)], dtype=np.uint8
)
else:
tmp_preds = np.zeros(
[h, w, np.sum(self.class_flags)], dtype=np.float32
)

# make predictions and forward them to the final prediction array
# make temporary 2D prediction array to get results by reference for each channel individually
if self.sing_mask:
tmp_preds = np.zeros([h, w], dtype=np.int16)
elif self.thresholds is not None:
tmp_preds = np.zeros(
[h, w, self.effective_class_nb], dtype=np.uint8
)
else:
tmp_preds = np.zeros(
[h, w, self.effective_class_nb], dtype=np.float32
)

# make predictions and forward them to the final prediction array
for ch in tqdm.tqdm(range(c), desc="CUBE CHANNELS"):
ch_im_data = hdu_data[ch]
self.process_image(ch_im_data, tmp_preds, tf_model)
preds[ch] = tmp_preds

if not self.sing_mask:
preds = np.transpose(preds, (0, 3, 1, 2))

elif len(hdu_shape) == 2:
self.process_image(hdu_data, preds, tf_model)

return preds
if not self.sing_mask:
preds = np.transpose(preds, (2, 0, 1))

return preds

def process_image(self, im_data, preds, tf_model):
"""Processes 2D image data.
Expand Down Expand Up @@ -371,11 +378,6 @@ def process_image(self, im_data, preds, tf_model):
batch_coord_list = block_coord_list[-rest:]
self.process_batch(im_data, preds, tf_model, batch_coord_list)

if not self.sing_mask:
preds = np.transpose(preds, (2, 0, 1))

return preds

def get_block_coords(self, h, w):
"""Gets the coordinate list of blocks to process.
Expand Down

0 comments on commit 63c4c8f

Please sign in to comment.