Skip to content

Commit ea3347f

Browse files
committed
changes everything over to use PIL image for both mask and depth get_image. also cleans up lots of dataset code
1 parent 676922f commit ea3347f

File tree

3 files changed

+62
-170
lines changed

3 files changed

+62
-170
lines changed

dense_correspondence/correspondence_tools/correspondence_augmentation.py

+19-18
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,14 @@
1515
import random
1616

1717

18-
def random_image_and_indices_mutation(image, uv_pixel_positions):
18+
def random_image_and_indices_mutation(images, uv_pixel_positions):
1919
"""
20-
This function takes an image and a list of pixel positions in the image,
20+
This function takes a list of images and a list of pixel positions in the image,
2121
and picks some subset of available mutations.
2222
23-
:param image: image for which to augment
24-
:type image: PIL.image.image
23+
:param images: a list of images (for example the rgb, depth, and mask) for which the
24+
**same** mutation will be applied
25+
:type images: list of PIL.Image.Image
2526
2627
:param uv_pixel_positions: pixel locations (u, v) in the image.
2728
See doc/coordinate_conventions.md for definition of (u, v)
@@ -35,34 +36,34 @@ def random_image_and_indices_mutation(image, uv_pixel_positions):
3536
Note: aim is to support both torch.LongTensor and torch.FloatTensor,
3637
and return the mutated_uv_pixel_positions with same type
3738
38-
:return mutated_image, mutated_uv_pixel_positions
39-
:rtype: PIL.image.image, tuple of torch Tensors
39+
:return: mutated_image_list, mutated_uv_pixel_positions
40+
:rtype: list of PIL.Image.Image, tuple of torch Tensors
4041
4142
"""
42-
mutated_image, mutated_uv_pixel_positions = random_flip_vertical(image, uv_pixel_positions)
43-
mutated_image, mutated_uv_pixel_positions = random_flip_horizontal(mutated_image, mutated_uv_pixel_positions)
44-
return mutated_image, mutated_uv_pixel_positions
43+
mutated_images, mutated_uv_pixel_positions = random_flip_vertical(images, uv_pixel_positions)
44+
mutated_images, mutated_uv_pixel_positions = random_flip_horizontal(mutated_images, mutated_uv_pixel_positions)
45+
return mutated_images, mutated_uv_pixel_positions
4546

4647

47-
def random_flip_vertical(image, uv_pixel_positions):
48+
def random_flip_vertical(images, uv_pixel_positions):
4849
"""
49-
Randomly flip the image and the pixel positions vertically (flip up/down)
50+
Randomly flip the images and the pixel positions vertically (flip up/down)
5051
5152
See random_image_and_indices_mutation() for documentation of args and return types.
5253
5354
"""
5455

5556
if random.random() < 0.5:
56-
return image, uv_pixel_positions # Randomly do not apply
57+
return images, uv_pixel_positions # Randomly do not apply
5758

5859
print "Flip vertically"
59-
mutated_image = ImageOps.flip(image)
60+
mutated_images = [ImageOps.flip(image) for image in images]
6061
v_pixel_positions = uv_pixel_positions[1]
6162
mutated_v_pixel_positions = image.height - v_pixel_positions
6263
mutated_uv_pixel_positions = (uv_pixel_positions[0], mutated_v_pixel_positions)
63-
return mutated_image, mutated_uv_pixel_positions
64+
return mutated_images, mutated_uv_pixel_positions
6465

65-
def random_flip_horizontal(image, uv_pixel_positions):
66+
def random_flip_horizontal(images, uv_pixel_positions):
6667
"""
6768
Randomly flip the image and the pixel positions horizontally (flip left/right)
6869
@@ -71,11 +72,11 @@ def random_flip_horizontal(image, uv_pixel_positions):
7172
"""
7273

7374
if random.random() < 0.5:
74-
return image, uv_pixel_positions # Randomly do not apply
75+
return images, uv_pixel_positions # Randomly do not apply
7576

7677
print "Flip left and right"
77-
mutated_image = ImageOps.mirror(image)
78+
mutated_images = [ImageOps.mirror(image) for image in images]
7879
u_pixel_positions = uv_pixel_positions[0]
7980
mutated_u_pixel_positions = image.width - u_pixel_positions
8081
mutated_uv_pixel_positions = (mutated_u_pixel_positions, uv_pixel_positions[1])
81-
return mutated_image, mutated_uv_pixel_positions
82+
return mutated_images, mutated_uv_pixel_positions

dense_correspondence/dataset/dense_correspondence_dataset_masked.py

+19-107
Original file line numberDiff line numberDiff line change
@@ -61,26 +61,23 @@ def __getitem__(self, index):
6161
dtype_long = torch.LongTensor
6262

6363
# pick a scene
64-
scene_directory = self.get_random_scene_directory()
6564
scene_name = self.get_random_scene_name()
6665

67-
68-
6966
# image a
70-
img_a_idx = self.get_random_image_index(scene_name, )
71-
image_a_rgb, image_a_depth, image_a_mask, image_a_pose = self.get_rgbd_mask_pose(scene_name, img_a_idx)
67+
image_a_idx = self.get_random_image_index(scene_name)
68+
image_a_rgb, image_a_depth, image_a_mask, image_a_pose = self.get_rgbd_mask_pose(scene_name, image_a_idx)
7269

7370
# image b
74-
img_b_idx = self.get_img_idx_with_different_pose(scene_name, image_a_pose, num_attempts=50)
71+
image_b_idx = self.get_img_idx_with_different_pose(scene_name, image_a_pose, num_attempts=50)
7572

76-
if img_b_idx is None:
73+
if image_b_idx is None:
7774
logging.info("no frame with sufficiently different pose found, returning")
7875
print "no frame with sufficiently different pose found, returning"
7976
return "matches", image_a_rgb, image_a_rgb, torch.zeros(1).type(dtype_long), torch.zeros(1).type(
8077
dtype_long), torch.zeros(1).type(dtype_long), torch.zeros(1).type(dtype_long)
8178

8279

83-
image_b_rgb, image_b_depth, image_b_mask, image_b_pose = self.get_rgbd_mask_pose(scene_name, img_b_idx)
80+
image_b_rgb, image_b_depth, image_b_mask, image_b_pose = self.get_rgbd_mask_pose(scene_name, image_b_idx)
8481

8582

8683
num_attempts = 50000
@@ -97,7 +94,6 @@ def __getitem__(self, index):
9794
image_b_depth_numpy, image_b_pose,
9895
num_attempts=num_attempts, img_a_mask=np.asarray(image_a_mask))
9996

100-
10197
if uv_a is None:
10298
print "No matches this time"
10399
return "matches", image_a_rgb, image_b_rgb, torch.zeros(1).type(dtype_long), torch.zeros(1).type(dtype_long), torch.zeros(1).type(dtype_long), torch.zeros(1).type(dtype_long)
@@ -152,8 +148,6 @@ def __getitem__(self, index):
152148
if self.tensor_transform is not None:
153149
image_a_rgb, image_b_rgb = self.both_to_tensor([image_a_rgb, image_b_rgb])
154150

155-
156-
157151
uv_a_long = (torch.t(uv_a[0].repeat(num_non_matches_per_match, 1)).contiguous().view(-1,1),
158152
torch.t(uv_a[1].repeat(num_non_matches_per_match, 1)).contiguous().view(-1,1))
159153
uv_b_non_matches_long = (uv_b_non_matches[0].view(-1,1), uv_b_non_matches[1].view(-1,1) )
@@ -166,28 +160,17 @@ def __getitem__(self, index):
166160
uv_a_long = uv_a_long.squeeze(1)
167161
uv_b_non_matches_long = uv_b_non_matches_long.squeeze(1)
168162

169-
170163
return "matches", image_a_rgb, image_b_rgb, uv_a, uv_b, uv_a_long, uv_b_non_matches_long
171164

172-
def get_random_rgbd_with_pose(self, scene_directory):
173-
rgb_filename = self.get_random_rgb_image_filename(scene_directory)
174-
depth_filename = self.get_depth_filename(rgb_filename)
175-
176-
rgb = self.get_rgb_image(rgb_filename)
177-
depth = self.get_depth_image(depth_filename)
178-
pose = self.get_pose(rgb_filename)
179-
180-
return rgb, depth, pose
181-
182165
def get_rgbd_mask_pose(self, scene_name, img_idx):
183166
"""
184167
Returns rgb image, depth image, mask and pose.
185168
:param scene_name:
186169
:type scene_name: str
187170
:param img_idx:
188171
:type img_idx: int
189-
:return:
190-
:rtype:
172+
:return: rgb, depth, mask, pose
173+
:rtype: PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, a 4x4 numpy array
191174
"""
192175
rgb_file = self.get_image_filename(scene_name, img_idx, ImageType.RGB)
193176
rgb = self.get_rgb_image(rgb_file)
@@ -202,27 +185,6 @@ def get_rgbd_mask_pose(self, scene_name, img_idx):
202185

203186
return rgb, depth, mask, pose
204187

205-
def get_random_rgbd_with_pose_and_mask(self, scene_directory):
206-
rgb_filename = self.get_random_rgb_image_filename(scene_directory)
207-
depth_filename = self.get_depth_filename(rgb_filename)
208-
mask_filename = self.get_mask_filename(rgb_filename)
209-
210-
rgb = self.get_rgb_image(rgb_filename)
211-
depth = self.get_depth_image(depth_filename)
212-
pose = self.get_pose(rgb_filename)
213-
mask = self.get_mask_image(mask_filename)
214-
215-
return rgb, depth, pose, mask
216-
217-
def get_random_rgb_with_mask(self, scene_directory):
218-
rgb_filename = self.get_random_rgb_image_filename(scene_directory)
219-
mask_filename = self.get_mask_filename(rgb_filename)
220-
221-
rgb = self.get_rgb_image(rgb_filename)
222-
mask = self.get_mask_image(mask_filename)
223-
224-
return rgb, mask
225-
226188
def get_img_idx_with_different_pose(self, scene_name, pose_a, threshold=0.2, num_attempts=10):
227189
"""
228190
Try to get an image with a different pose to the one passed in. If one can't be found
@@ -235,8 +197,8 @@ def get_img_idx_with_different_pose(self, scene_name, pose_a, threshold=0.2, num
235197
:type threshold:
236198
:param num_attempts:
237199
:type num_attempts:
238-
:return:
239-
:rtype:
200+
:return: an index with a different-enough pose
201+
:rtype: int or None
240202
"""
241203

242204
counter = 0
@@ -252,41 +214,11 @@ def get_img_idx_with_different_pose(self, scene_name, pose_a, threshold=0.2, num
252214
return None
253215

254216

255-
def get_different_rgbd_with_pose(self, scene_directory, image_a_pose):
256-
# try to get a far-enough-away pose
257-
# if can't, then just return last sampled pose
258-
num_attempts = 0
259-
while num_attempts < 10:
260-
rgb_filename = self.get_random_rgb_image_filename(scene_directory)
261-
depth_filename = self.get_depth_filename(rgb_filename)
262-
pose = self.get_pose(rgb_filename)
263-
if self.different_enough(image_a_pose, pose):
264-
break
265-
num_attempts += 1
266-
267-
rgb = self.get_rgb_image(rgb_filename)
268-
depth = self.get_depth_image(depth_filename)
269-
return rgb, depth, pose
270-
271-
def get_different_rgbd_with_pose_and_mask(self, scene_directory, image_a_pose):
272-
# try to get a far-enough-away pose
273-
# if can't, then just return last sampled pose
274-
num_attempts = 0
275-
while num_attempts < 10:
276-
rgb_filename = self.get_random_rgb_image_filename(scene_directory)
277-
depth_filename = self.get_depth_filename(rgb_filename)
278-
pose = self.get_pose(rgb_filename)
279-
mask_filename = self.get_mask_filename(rgb_filename)
280-
if self.different_enough(image_a_pose, pose):
281-
break
282-
num_attempts += 1
283-
284-
rgb = self.get_rgb_image(rgb_filename)
285-
depth = self.get_depth_image(depth_filename)
286-
mask = self.get_mask_image(mask_filename)
287-
return rgb, depth, pose, mask
288-
289217
def get_rgb_image(self, rgb_filename):
218+
"""
219+
:param rgb_filename: string of full path to rgb image
220+
:return: PIL.Image.Image, in particular an 'RGB' PIL image
221+
"""
290222
return Image.open(rgb_filename).convert('RGB')
291223

292224
def get_rgb_image_from_scene_name_and_idx(self, scene_name, img_idx):
@@ -300,20 +232,17 @@ def get_rgb_image_from_scene_name_and_idx(self, scene_name, img_idx):
300232
return self.get_rgb_image(img_filename)
301233

302234
def get_depth_image(self, depth_filename):
303-
return Image.open(depth_filename)
304-
305-
def get_depth_image_from_scene_name_and_idx(self, scene_name, img_idx):
306235
"""
307-
Returns a depth image given a scene_name and image index
308-
:param scene_name:
309-
:param img_idx: str or int
236+
:param depth_filename: string of full path to depth image
310237
:return: PIL.Image.Image
311238
"""
312-
313-
img_filename = self.get_image_filename(scene_name, img_idx, ImageType.DEPTH)
314-
return self.get_depth_image(img_filename)
239+
return Image.open(depth_filename)
315240

316241
def get_mask_image(self, mask_filename):
242+
"""
243+
:param mask_filename: string of full path to mask image
244+
:return: PIL.Image.Image
245+
"""
317246
return Image.open(mask_filename)
318247

319248
def get_mask_image_from_scene_name_and_idx(self, scene_name, img_idx):
@@ -344,23 +273,6 @@ def get_random_rgb_image_filename(self, scene_directory):
344273
random_rgb_image = random.choice(all_rgb_images_in_scene)
345274
return random_rgb_image
346275

347-
def get_specific_rgbd_with_pose(self, scene_name, img_index):
348-
"""
349-
Returns a rgbd image along with the camera pose for a specific image
350-
in a specific scene
351-
:param scene_name:
352-
:param img_index:
353-
:return:
354-
"""
355-
rgb_filename = self.get_specific_rgb_image_filename(scene_name, img_index)
356-
depth_filename = self.get_depth_filename(rgb_filename)
357-
358-
rgb = self.get_rgb_image(rgb_filename)
359-
depth = self.get_depth_image(depth_filename)
360-
pose = self.get_pose(rgb_filename)
361-
362-
return rgb, depth, pose
363-
364276
def get_specific_rgb_image_filname(self, scene_name, img_index):
365277
DeprecationWarning("use get_specific_rgb_image_filename instead")
366278
return self.get_specific_rgb_image_filename(scene_name, img_index)

0 commit comments

Comments
 (0)