Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes #98 - Enable the use of dataset_volume.py for unsupervised cases (label/mask = None) #101

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 23 additions & 18 deletions connectomics/data/dataset/dataset_volume.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,11 +179,14 @@ def _process_targets(self, sample):
out_volume = normalize_image(out_volume, self.data_mean, self.data_std)

# output list
out_target = seg_to_targets(
out_label, self.target_opt, self.erosion_rates, self.dilation_rates)
out_weight = seg_to_weights(
out_target, self.weight_opt, out_valid, out_label)
return pos, out_volume, out_target, out_weight
if out_label is None:
return pos, out_volume, None, None
else:
out_target = self._seg_to_targets(
out_label, self.target_opt, self.erosion_rates, self.dilation_rates)
out_weight = self._seg_to_weights(
out_target, self.weight_opt, out_valid, out_label)
return pos, out_volume, out_target, out_weight

#######################################################
# Position Calculator
Expand Down Expand Up @@ -271,27 +274,29 @@ def _crop_with_pos(self, pos, vol_size):
out_volume = (crop_volume(
self.volume[pos[0]], vol_size, pos[1:])/255.0).astype(np.float32)
# position in the label and valid mask
pos_l = np.round(pos[1:]*self.label_vol_ratio)
out_label = crop_volume(
self.label[pos[0]], self.sample_label_size, pos_l)
# For warping: cv2.remap requires input to be float32.
# Make labels index smaller. Otherwise uint32 and float32 are not
# the same for some values.
out_label = relabel(out_label.copy()).astype(np.float32)

out_label = None
out_valid = None
if self.valid_mask is not None:
out_valid = crop_volume(self.label[pos[0]],
self.sample_label_size, pos_l)
out_valid = (out_valid != 0).astype(np.float32)
if self.label is not None:
pos_l = np.round(pos[1:]*self.label_vol_ratio)
out_label = self._crop_volume(
self.label[pos[0]], self.sample_label_size, pos_l)
# For warping: cv2.remap requires input to be float32.
# Make labels index smaller. Otherwise uint32 and float32 are not
# the same for some values.
out_label = self._relabel(out_label.copy()).astype(np.float32)

if self.valid_mask is not None:
out_valid = self._crop_volume(self.label[pos[0]],
self.sample_label_size, pos_l)
out_valid = (out_valid != 0).astype(np.float32)

return pos, out_volume, out_label, out_valid

def _is_valid(self, out_valid: np.ndarray) -> bool:
"""Decide whether the sampled region is valid or not using
the corresponding valid mask.
"""
if self.valid_mask is None:
if self.valid_mask is None or out_valid is None:
return True
ratio = float(out_valid.sum()) / np.prod(np.array(out_valid.shape))
return ratio > self.valid_ratio
Expand Down