Skip to content

Commit ca1e8ba

Browse files
committed
significantly more dataset cleaning
1 parent 5475a57 commit ca1e8ba

File tree

1 file changed

+74
-88
lines changed

1 file changed

+74
-88
lines changed

dense_correspondence/dataset/dense_correspondence_dataset_masked.py

+74-88
Original file line numberDiff line numberDiff line change
@@ -37,28 +37,53 @@ class ImageType:
3737

3838
class DenseCorrespondenceDataset(data.Dataset):
3939

40-
def __init__(self, debug=False, training_config=None):
    """
    Base dataset for dense-correspondence training.

    :param debug: when True, sample only a handful of matches so debug
        plots stay readable
    :param training_config: optional dict with keys
        "num_matching_attempts" and "num_non_matches_per_match";
        when None, built-in defaults are used
    """
    self.debug = debug
    self.mode = "train"
    self.both_to_tensor = ComposeJoint(
        [
            [transforms.ToTensor(), transforms.ToTensor()]
        ])

    # identity comparison with None, per PEP 8 (was `== None`)
    if training_config is None:
        self.num_matching_attempts = 50000
        self.num_non_matches_per_match = 150
    else:
        self.num_matching_attempts = training_config["num_matching_attempts"]
        self.num_non_matches_per_match = training_config["num_non_matches_per_match"]

    if self.debug:
        # BUG FIX: original assigned `self.num_attempts`, an attribute
        # nothing reads -- __getitem__ uses self.num_matching_attempts,
        # so debug mode silently kept the full 50000 attempts
        self.num_matching_attempts = 20
        self.num_non_matches_per_match = 1
5658

5759
def __len__(self):
    """Dataset length reported to the DataLoader: total image count."""
    total = self.num_images_total
    return total
5961

6062
def __getitem__(self, index):
61-
dtype_long = torch.LongTensor
63+
"""
64+
The method through which the dataset is accessed for training.
65+
66+
The index param is not currently used, and instead each dataset[i] is the result of
67+
a random sampling over:
68+
- random scene
69+
- random rgbd frame from that scene
70+
- random rgbd frame (different enough pose) from that scene
71+
- various randomization in the match generation and non-match generation procedure
72+
73+
returns a large amount of variables, separated by commas.
74+
75+
0th return arg: the type of data sampled (this can be used as a flag for different loss functions)
76+
0th rtype: string
77+
78+
1st, 2nd return args: image_a_rgb, image_b_rgb
79+
1st, 2nd rtype: 3-dimensional torch.FloatTensor of shape (image_height, image_width, 3)
80+
81+
3rd, 4th return args: matches_a, matches_b
82+
3rd, 4th rtype: 1-dimensional torch.LongTensor of shape (num_matches)
83+
84+
5th, 6th return args: non_matches_a, non_matches_b
85+
5th, 6th rtype: 1-dimensional torch.LongTensor of shape (num_non_matches)
86+
"""
6287

6388
# pick a scene
6489
scene_name = self.get_random_scene_name()
@@ -72,31 +97,22 @@ def __getitem__(self, index):
7297

7398
if image_b_idx is None:
7499
logging.info("no frame with sufficiently different pose found, returning")
75-
print "no frame with sufficiently different pose found, returning"
76-
return "matches", image_a_rgb, image_a_rgb, torch.zeros(1).type(dtype_long), torch.zeros(1).type(
77-
dtype_long), torch.zeros(1).type(dtype_long), torch.zeros(1).type(dtype_long)
78-
100+
# TODO: return something cleaner than no-data
101+
return self.return_empty_data(image_a_rgb, image_b_rgb)
79102

80103
image_b_rgb, image_b_depth, image_b_mask, image_b_pose = self.get_rgbd_mask_pose(scene_name, image_b_idx)
81104

82-
83-
num_attempts = 50000
84-
num_non_matches_per_match = 150
85-
if self.debug:
86-
num_attempts = 20
87-
num_non_matches_per_match = 1
88-
89105
image_a_depth_numpy = np.asarray(image_a_depth)
90106
image_b_depth_numpy = np.asarray(image_b_depth)
91107

92108
# find correspondences
93109
uv_a, uv_b = correspondence_finder.batch_find_pixel_correspondences(image_a_depth_numpy, image_a_pose,
94110
image_b_depth_numpy, image_b_pose,
95-
num_attempts=num_attempts, img_a_mask=np.asarray(image_a_mask))
111+
num_attempts=self.num_matching_attempts, img_a_mask=np.asarray(image_a_mask))
96112

97113
if uv_a is None:
98-
print "No matches this time"
99-
return "matches", image_a_rgb, image_b_rgb, torch.zeros(1).type(dtype_long), torch.zeros(1).type(dtype_long), torch.zeros(1).type(dtype_long), torch.zeros(1).type(dtype_long)
114+
logging.info("no matches found, returning")
115+
return self.return_empty_data(image_a_rgb, image_b_rgb)
100116

101117
if self.debug:
102118
# downsample so can plot
@@ -106,24 +122,25 @@ def __getitem__(self, index):
106122
uv_b = (torch.index_select(uv_b[0], 0, indexes_to_keep), torch.index_select(uv_b[1], 0, indexes_to_keep))
107123

108124
# data augmentation
109-
if not self.debug:
110-
[image_a_rgb], uv_a = correspondence_augmentation.random_image_and_indices_mutation([image_a_rgb], uv_a)
111-
[image_b_rgb, image_b_mask], uv_b = correspondence_augmentation.random_image_and_indices_mutation([image_b_rgb, image_b_mask], uv_b)
112-
else: # also mutate depth just for plotting
113-
[image_a_rgb, image_a_depth], uv_a = correspondence_augmentation.random_image_and_indices_mutation([image_a_rgb, image_a_depth], uv_a)
114-
[image_b_rgb, image_b_depth, image_b_mask], uv_b = correspondence_augmentation.random_image_and_indices_mutation([image_b_rgb, image_b_depth, image_b_mask], uv_b)
115-
image_a_depth_numpy = np.asarray(image_a_depth)
116-
image_b_depth_numpy = np.asarray(image_b_depth)
125+
# if not self.debug:
126+
# [image_a_rgb], uv_a = correspondence_augmentation.random_image_and_indices_mutation([image_a_rgb], uv_a)
127+
# [image_b_rgb, image_b_mask], uv_b = correspondence_augmentation.random_image_and_indices_mutation([image_b_rgb, image_b_mask], uv_b)
128+
# else: # also mutate depth just for plotting
129+
# [image_a_rgb, image_a_depth], uv_a = correspondence_augmentation.random_image_and_indices_mutation([image_a_rgb, image_a_depth], uv_a)
130+
# [image_b_rgb, image_b_depth, image_b_mask], uv_b = correspondence_augmentation.random_image_and_indices_mutation([image_b_rgb, image_b_depth, image_b_mask], uv_b)
131+
# image_a_depth_numpy = np.asarray(image_a_depth)
132+
# image_b_depth_numpy = np.asarray(image_b_depth)
117133

118134
# find non_correspondences
135+
119136
if index%2:
120137
print "masking non-matches"
121138
image_b_mask = torch.from_numpy(np.asarray(image_b_mask)).type(torch.FloatTensor)
122139
else:
123140
print "not masking non-matches"
124141
image_b_mask = None
125142

126-
uv_b_non_matches = correspondence_finder.create_non_correspondences(uv_b, num_non_matches_per_match=num_non_matches_per_match, img_b_mask=image_b_mask)
143+
uv_b_non_matches = correspondence_finder.create_non_correspondences(uv_b, num_non_matches_per_match=self.num_non_matches_per_match, img_b_mask=image_b_mask)
127144

128145
if self.debug:
129146
# only want to bring in plotting code if in debug mode
@@ -132,8 +149,8 @@ def __getitem__(self, index):
132149
# Just show all images
133150
# self.debug_show_data(image_a_rgb, image_a_depth, image_b_pose,
134151
# image_b_rgb, image_b_depth, image_b_pose)
135-
uv_a_long = (torch.t(uv_a[0].repeat(num_non_matches_per_match, 1)).contiguous().view(-1,1),
136-
torch.t(uv_a[1].repeat(num_non_matches_per_match, 1)).contiguous().view(-1,1))
152+
uv_a_long = (torch.t(uv_a[0].repeat(self.num_non_matches_per_match, 1)).contiguous().view(-1,1),
153+
torch.t(uv_a[1].repeat(self.num_non_matches_per_match, 1)).contiguous().view(-1,1))
137154
uv_b_non_matches_long = (uv_b_non_matches[0].view(-1,1), uv_b_non_matches[1].view(-1,1) )
138155

139156
# Show correspondences
@@ -144,23 +161,24 @@ def __getitem__(self, index):
144161
use_previous_plot=(fig,axes),
145162
circ_color='r')
146163

164+
image_a_rgb, image_b_rgb = self.both_to_tensor([image_a_rgb, image_b_rgb])
147165

148-
if self.tensor_transform is not None:
149-
image_a_rgb, image_b_rgb = self.both_to_tensor([image_a_rgb, image_b_rgb])
150-
151-
uv_a_long = (torch.t(uv_a[0].repeat(num_non_matches_per_match, 1)).contiguous().view(-1,1),
152-
torch.t(uv_a[1].repeat(num_non_matches_per_match, 1)).contiguous().view(-1,1))
166+
uv_a_long = (torch.t(uv_a[0].repeat(self.num_non_matches_per_match, 1)).contiguous().view(-1,1),
167+
torch.t(uv_a[1].repeat(self.num_non_matches_per_match, 1)).contiguous().view(-1,1))
153168
uv_b_non_matches_long = (uv_b_non_matches[0].view(-1,1), uv_b_non_matches[1].view(-1,1) )
154169

155170
# flatten correspondences and non_correspondences
156-
uv_a = uv_a[1].type(dtype_long)*640+uv_a[0].type(dtype_long)
157-
uv_b = uv_b[1].type(dtype_long)*640+uv_b[0].type(dtype_long)
158-
uv_a_long = uv_a_long[1].type(dtype_long)*640+uv_a_long[0].type(dtype_long)
159-
uv_b_non_matches_long = uv_b_non_matches_long[1].type(dtype_long)*640+uv_b_non_matches_long[0].type(dtype_long)
160-
uv_a_long = uv_a_long.squeeze(1)
161-
uv_b_non_matches_long = uv_b_non_matches_long.squeeze(1)
171+
matches_a = uv_a[1].long()*640+uv_a[0].long()
172+
matches_b = uv_b[1].long()*640+uv_b[0].long()
173+
non_matches_a = uv_a_long[1].long()*640+uv_a_long[0].long()
174+
non_matches_a = non_matches_a.squeeze(1)
175+
non_matches_b = uv_b_non_matches_long[1].long()*640+uv_b_non_matches_long[0].long()
176+
non_matches_b = non_matches_b.squeeze(1)
162177

163-
return "matches", image_a_rgb, image_b_rgb, uv_a, uv_b, uv_a_long, uv_b_non_matches_long
178+
return "matches", image_a_rgb, image_b_rgb, matches_a, matches_b, non_matches_a, non_matches_b
179+
180+
def return_empty_data(self, image_a_rgb, image_b_rgb):
    """
    Placeholder sample used when no correspondences could be produced
    (no sufficiently different frame, or no matches found).

    :param image_a_rgb: first RGB image, passed through unchanged
    :param image_b_rgb: second RGB image, passed through unchanged
    :return: 7-tuple shaped like a normal __getitem__ sample:
        (None, image_a_rgb, image_b_rgb, matches_a, matches_b,
        non_matches_a, non_matches_b) with all index tensors a single
        zero LongTensor; the leading None flags "no data" to the caller
    """
    # BUG FIX: the original body lacked the `return` keyword, so it built
    # the tuple, discarded it, and implicitly returned None
    return None, image_a_rgb, image_b_rgb, torch.zeros(1).long(), \
        torch.zeros(1).long(), torch.zeros(1).long(), torch.zeros(1).long()
164182

165183
def get_rgbd_mask_pose(self, scene_name, img_idx):
166184
"""
@@ -238,6 +256,16 @@ def get_depth_image(self, depth_filename):
238256
"""
239257
return Image.open(depth_filename)
240258

259+
def get_depth_image_from_scene_name_and_idx(self, scene_name, img_idx):
    """
    Look up the depth-image path for a scene/frame pair and load it.

    :param scene_name: name of the scene
    :param img_idx: str or int frame index
    :return: PIL.Image.Image
    """
    depth_filename = self.get_image_filename(scene_name, img_idx, ImageType.DEPTH)
    return self.get_depth_image(depth_filename)
268+
241269
def get_mask_image(self, mask_filename):
242270
"""
243271
:param mask_filename: string of full path to mask image
@@ -255,43 +283,6 @@ def get_mask_image_from_scene_name_and_idx(self, scene_name, img_idx):
255283
img_filename = self.get_image_filename(scene_name, img_idx, ImageType.MASK)
256284
return self.get_mask_image(img_filename)
257285

258-
259-
def different_enough(self, pose_1, pose_2):
260-
translation_1 = np.asarray(pose_1[0,3], pose_1[1,3], pose_1[2,3])
261-
translation_2 = np.asarray(pose_2[0,3], pose_2[1,3], pose_2[2,3])
262-
263-
translation_threshold = 0.2 # meters
264-
if np.linalg.norm(translation_1 - translation_2) > translation_threshold:
265-
return True
266-
267-
# later implement something that is different_enough for rotations?
268-
return False
269-
270-
def get_random_rgb_image_filename(self, scene_directory):
271-
rgb_images_regex = os.path.join(scene_directory, "images/*_rgb.png")
272-
all_rgb_images_in_scene = sorted(glob.glob(rgb_images_regex))
273-
random_rgb_image = random.choice(all_rgb_images_in_scene)
274-
return random_rgb_image
275-
276-
def get_specific_rgb_image_filname(self, scene_name, img_index):
277-
DeprecationWarning("use get_specific_rgb_image_filename instead")
278-
return self.get_specific_rgb_image_filename(scene_name, img_index)
279-
280-
def get_specific_rgb_image_filename(self, scene_name, img_index):
281-
"""
282-
Returns the filename for the specific RGB image
283-
:param scene_name:
284-
:param img_index: int or str
285-
:return:
286-
"""
287-
if isinstance(img_index, int):
288-
img_index = utils.getPaddedString(img_index)
289-
290-
scene_directory = self.get_full_path_for_scene(scene_name)
291-
images_dir = os.path.join(scene_directory, "images")
292-
rgb_image_filename = os.path.join(images_dir, img_index + "_rgb.png")
293-
return rgb_image_filename
294-
295286
def get_image_filename(self, scene_name, img_index, image_type):
    """
    Return the full path to the image of the given type for this frame.

    Abstract: concrete dataset subclasses must override this.

    :param scene_name: name of the scene
    :param img_index: int or str frame index
    :param image_type: one of the ImageType constants (e.g. DEPTH, MASK)
    :raises NotImplementedError: always, in this base class
    """
    # BUG FIX: message said "Implement in superclass"; subclasses are the
    # implementers (cf. get_pose_from_scene_name_and_idx's message)
    raise NotImplementedError("Implement in subclass")
297288

@@ -313,11 +304,6 @@ def get_pose_from_scene_name_and_idx(self, scene_name, idx):
313304
"""
314305
raise NotImplementedError("subclass must implement this method")
315306

316-
def get_depth_filename(self, rgb_image):
317-
prefix = rgb_image.split("rgb")[0]
318-
depth_filename = prefix+"depth.png"
319-
return depth_filename
320-
321307
# this function cowboy copied from:
322308
# https://www.lfd.uci.edu/~gohlke/code/transformations.py.html
323309
def quaternion_matrix(self, quaternion):

0 commit comments

Comments
 (0)