From 3df3fa03ecaa826e73c6cdc10c5d466b170396f2 Mon Sep 17 00:00:00 2001 From: Leander Lauenburg Date: Sat, 29 Jan 2022 08:13:42 -0500 Subject: [PATCH] added multiple switches that ensure the functionality of the volume data loader in case of labels and/or valid_mask is set to None --- connectomics/data/dataset/dataset_volume.py | 41 ++++++++++++--------- 1 file changed, 23 insertions(+), 18 deletions(-) diff --git a/connectomics/data/dataset/dataset_volume.py b/connectomics/data/dataset/dataset_volume.py index 9975602d..7d286b36 100755 --- a/connectomics/data/dataset/dataset_volume.py +++ b/connectomics/data/dataset/dataset_volume.py @@ -179,11 +179,14 @@ def _process_targets(self, sample): out_volume = normalize_image(out_volume, self.data_mean, self.data_std) # output list - out_target = seg_to_targets( - out_label, self.target_opt, self.erosion_rates, self.dilation_rates) - out_weight = seg_to_weights( - out_target, self.weight_opt, out_valid, out_label) - return pos, out_volume, out_target, out_weight + if out_label is None: + return pos, out_volume, None, None + else: + out_target = self._seg_to_targets( + out_label, self.target_opt, self.erosion_rates, self.dilation_rates) + out_weight = self._seg_to_weights( + out_target, self.weight_opt, out_valid, out_label) + return pos, out_volume, out_target, out_weight ####################################################### # Position Calculator @@ -271,19 +274,21 @@ def _crop_with_pos(self, pos, vol_size): out_volume = (crop_volume( self.volume[pos[0]], vol_size, pos[1:])/255.0).astype(np.float32) # position in the label and valid mask - pos_l = np.round(pos[1:]*self.label_vol_ratio) - out_label = crop_volume( - self.label[pos[0]], self.sample_label_size, pos_l) - # For warping: cv2.remap requires input to be float32. - # Make labels index smaller. Otherwise uint32 and float32 are not - # the same for some values. - out_label = relabel(out_label.copy()).astype(np.float32) - + out_label = None out_valid = None - if self.valid_mask is not None: - out_valid = crop_volume(self.label[pos[0]], - self.sample_label_size, pos_l) - out_valid = (out_valid != 0).astype(np.float32) + if self.label is not None: + pos_l = np.round(pos[1:]*self.label_vol_ratio) + out_label = self._crop_volume( + self.label[pos[0]], self.sample_label_size, pos_l) + # For warping: cv2.remap requires input to be float32. + # Make labels index smaller. Otherwise uint32 and float32 are not + # the same for some values. + out_label = self._relabel(out_label.copy()).astype(np.float32) + + if self.valid_mask is not None: + out_valid = self._crop_volume(self.label[pos[0]], + self.sample_label_size, pos_l) + out_valid = (out_valid != 0).astype(np.float32) return pos, out_volume, out_label, out_valid @@ -291,7 +296,7 @@ def _is_valid(self, out_valid: np.ndarray) -> bool: """Decide whether the sampled region is valid or not using the corresponding valid mask. """ - if self.valid_mask is None: + if self.valid_mask is None or out_valid is None: return True ratio = float(out_valid.sum()) / np.prod(np.array(out_valid.shape)) return ratio > self.valid_ratio