@@ -102,6 +102,7 @@ def __init__(
102
102
if self .class_flags [k ]:
103
103
class_idx .append (k )
104
104
self .class_idx = np .array (class_idx , dtype = np .int32 )
105
+ self .effective_class_nb = np .sum (class_flags )
105
106
106
107
# builds the prior factor values to be given for inference
107
108
self .priors = priors
@@ -298,11 +299,11 @@ def process_hdu(self, file_name, hdu_idx, task, tf_model):
298
299
preds = np .zeros_like (hdu_data , dtype = np .int16 )
299
300
elif self .thresholds is not None :
300
301
preds = np .zeros (
301
- list (hdu_shape ) + [np . sum ( self .class_flags ) ], dtype = np .uint8
302
+ list (hdu_shape ) + [self .effective_class_nb ], dtype = np .uint8
302
303
)
303
304
else :
304
305
preds = np .zeros (
305
- list (hdu_shape ) + [np . sum ( self .class_flags ) ], dtype = np .float32
306
+ list (hdu_shape ) + [self .effective_class_nb ], dtype = np .float32
306
307
)
307
308
308
309
# preprocessing
@@ -315,29 +316,35 @@ def process_hdu(self, file_name, hdu_idx, task, tf_model):
315
316
# process the HDU 3D or 2D data
316
317
if len (hdu_shape ) == 3 :
317
318
c , h , w = hdu_shape
318
- for ch in tqdm .tqdm (range (c ), desc = "CUBE CHANNELS" ):
319
319
320
- # make temporary 2D prediction array to get results by reference
321
- if self .sing_mask :
322
- tmp_preds = np .zeros_like ([h , w ], dtype = np .int16 )
323
- elif self .thresholds is not None :
324
- tmp_preds = np .zeros (
325
- [h , w , np .sum (self .class_flags )], dtype = np .uint8
326
- )
327
- else :
328
- tmp_preds = np .zeros (
329
- [h , w , np .sum (self .class_flags )], dtype = np .float32
330
- )
331
-
332
- # make predictions and forward them to the final prediction array
320
+ # make temporary 2D prediction array to get results by reference for each channel individually
321
+ if self .sing_mask :
322
+ tmp_preds = np .zeros ([h , w ], dtype = np .int16 )
323
+ elif self .thresholds is not None :
324
+ tmp_preds = np .zeros (
325
+ [h , w , self .effective_class_nb ], dtype = np .uint8
326
+ )
327
+ else :
328
+ tmp_preds = np .zeros (
329
+ [h , w , self .effective_class_nb ], dtype = np .float32
330
+ )
331
+
332
+ # make predictions and forward them to the final prediction array
333
+ for ch in tqdm .tqdm (range (c ), desc = "CUBE CHANNELS" ):
333
334
ch_im_data = hdu_data [ch ]
334
335
self .process_image (ch_im_data , tmp_preds , tf_model )
335
336
preds [ch ] = tmp_preds
336
337
338
+ if not self .sing_mask :
339
+ preds = np .transpose (preds , (0 , 3 , 1 , 2 ))
340
+
337
341
elif len (hdu_shape ) == 2 :
338
342
self .process_image (hdu_data , preds , tf_model )
339
343
340
- return preds
344
+ if not self .sing_mask :
345
+ preds = np .transpose (preds , (2 , 0 , 1 ))
346
+
347
+ return preds
341
348
342
349
def process_image (self , im_data , preds , tf_model ):
343
350
"""Processes 2D image data.
@@ -371,11 +378,6 @@ def process_image(self, im_data, preds, tf_model):
371
378
batch_coord_list = block_coord_list [- rest :]
372
379
self .process_batch (im_data , preds , tf_model , batch_coord_list )
373
380
374
- if not self .sing_mask :
375
- preds = np .transpose (preds , (2 , 0 , 1 ))
376
-
377
- return preds
378
-
379
381
def get_block_coords (self , h , w ):
380
382
"""Gets the coordinate list of blocks to process.
381
383
0 commit comments