@@ -37,28 +37,53 @@ class ImageType:

class DenseCorrespondenceDataset(data.Dataset):

-    def __init__(self, debug=False):
+    def __init__(self, debug=False, training_config=None):

        self.debug = debug
-
        self.mode = "train"
-
        self.both_to_tensor = ComposeJoint(
            [
                [transforms.ToTensor(), transforms.ToTensor()]
            ])
+        if training_config is None:
+            self.num_matching_attempts = 50000
+            self.num_non_matches_per_match = 150
+        else:
+            self.num_matching_attempts = training_config["num_matching_attempts"]
+            self.num_non_matches_per_match = training_config["num_non_matches_per_match"]

-        self.tensor_transform = ComposeJoint(
-            [
-                [transforms.ToTensor(), None],
-                [None, transforms.Lambda(lambda x: torch.from_numpy(x).long())]
-            ])
+        if self.debug:
+            self.num_matching_attempts = 20
+            self.num_non_matches_per_match = 1

    def __len__(self):
        return self.num_images_total

    def __getitem__(self, index):
-        dtype_long = torch.LongTensor
+        """
+        The method through which the dataset is accessed for training.
+
+        The index param is not currently used; instead each dataset[i] is the result of
+        random sampling over:
+        - a random scene
+        - a random rgbd frame from that scene
+        - a random rgbd frame (with a sufficiently different pose) from that scene
+        - various randomization in the match generation and non-match generation procedure
+
+        Returns a number of variables, separated by commas.
+
+        0th return arg: the type of data sampled (this can be used as a flag for different loss functions)
+        0th rtype: string
+
+        1st, 2nd return args: image_a_rgb, image_b_rgb
+        1st, 2nd rtype: 3-dimensional torch.FloatTensor of shape (image_height, image_width, 3)
+
+        3rd, 4th return args: matches_a, matches_b
+        3rd, 4th rtype: 1-dimensional torch.LongTensor of shape (num_matches)
+
+        5th, 6th return args: non_matches_a, non_matches_b
+        5th, 6th rtype: 1-dimensional torch.LongTensor of shape (num_non_matches)
+        """

        # pick a scene
        scene_name = self.get_random_scene_name()
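
A minimal usage sketch of the new constructor argument and the return signature documented above. DenseCorrespondenceDataset is abstract (several getters below raise NotImplementedError), so the subclass name MyCorrespondenceDataset is a placeholder; only the two config keys and the unpacking order come from this diff.

    # hypothetical concrete subclass; the config keys match what __init__ reads above
    training_config = {
        "num_matching_attempts": 50000,
        "num_non_matches_per_match": 150,
    }
    dataset = MyCorrespondenceDataset(debug=False, training_config=training_config)

    # each access re-runs the random scene/frame sampling described in the docstring
    data_type, image_a_rgb, image_b_rgb, \
        matches_a, matches_b, non_matches_a, non_matches_b = dataset[0]
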
@@ -72,31 +97,22 @@ def __getitem__(self, index):

        if image_b_idx is None:
            logging.info("no frame with sufficiently different pose found, returning")
-            print "no frame with sufficiently different pose found, returning"
-            return "matches", image_a_rgb, image_a_rgb, torch.zeros(1).type(dtype_long), torch.zeros(1).type(
-                dtype_long), torch.zeros(1).type(dtype_long), torch.zeros(1).type(dtype_long)
-
+            # TODO: return something cleaner than no-data
+            return self.return_empty_data(image_a_rgb, image_a_rgb)

        image_b_rgb, image_b_depth, image_b_mask, image_b_pose = self.get_rgbd_mask_pose(scene_name, image_b_idx)

-
-        num_attempts = 50000
-        num_non_matches_per_match = 150
-        if self.debug:
-            num_attempts = 20
-            num_non_matches_per_match = 1
-
        image_a_depth_numpy = np.asarray(image_a_depth)
        image_b_depth_numpy = np.asarray(image_b_depth)

        # find correspondences
        uv_a, uv_b = correspondence_finder.batch_find_pixel_correspondences(image_a_depth_numpy, image_a_pose,
                                                                            image_b_depth_numpy, image_b_pose,
-                                                                           num_attempts=num_attempts, img_a_mask=np.asarray(image_a_mask))
+                                                                           num_attempts=self.num_matching_attempts, img_a_mask=np.asarray(image_a_mask))

        if uv_a is None:
-            print "No matches this time"
-            return "matches", image_a_rgb, image_b_rgb, torch.zeros(1).type(dtype_long), torch.zeros(1).type(dtype_long), torch.zeros(1).type(dtype_long), torch.zeros(1).type(dtype_long)
+            logging.info("no matches found, returning")
+            return self.return_empty_data(image_a_rgb, image_b_rgb)

        if self.debug:
            # downsample so can plot
@@ -106,24 +122,25 @@ def __getitem__(self, index):
            uv_b = (torch.index_select(uv_b[0], 0, indexes_to_keep), torch.index_select(uv_b[1], 0, indexes_to_keep))

        # data augmentation
-        if not self.debug:
-            [image_a_rgb], uv_a = correspondence_augmentation.random_image_and_indices_mutation([image_a_rgb], uv_a)
-            [image_b_rgb, image_b_mask], uv_b = correspondence_augmentation.random_image_and_indices_mutation([image_b_rgb, image_b_mask], uv_b)
-        else:  # also mutate depth just for plotting
-            [image_a_rgb, image_a_depth], uv_a = correspondence_augmentation.random_image_and_indices_mutation([image_a_rgb, image_a_depth], uv_a)
-            [image_b_rgb, image_b_depth, image_b_mask], uv_b = correspondence_augmentation.random_image_and_indices_mutation([image_b_rgb, image_b_depth, image_b_mask], uv_b)
-            image_a_depth_numpy = np.asarray(image_a_depth)
-            image_b_depth_numpy = np.asarray(image_b_depth)
+        # if not self.debug:
+        #     [image_a_rgb], uv_a = correspondence_augmentation.random_image_and_indices_mutation([image_a_rgb], uv_a)
+        #     [image_b_rgb, image_b_mask], uv_b = correspondence_augmentation.random_image_and_indices_mutation([image_b_rgb, image_b_mask], uv_b)
+        # else:  # also mutate depth just for plotting
+        #     [image_a_rgb, image_a_depth], uv_a = correspondence_augmentation.random_image_and_indices_mutation([image_a_rgb, image_a_depth], uv_a)
+        #     [image_b_rgb, image_b_depth, image_b_mask], uv_b = correspondence_augmentation.random_image_and_indices_mutation([image_b_rgb, image_b_depth, image_b_mask], uv_b)
+        #     image_a_depth_numpy = np.asarray(image_a_depth)
+        #     image_b_depth_numpy = np.asarray(image_b_depth)

        # find non_correspondences
+
        if index % 2:
            print "masking non-matches"
            image_b_mask = torch.from_numpy(np.asarray(image_b_mask)).type(torch.FloatTensor)
        else:
            print "not masking non-matches"
            image_b_mask = None

-        uv_b_non_matches = correspondence_finder.create_non_correspondences(uv_b, num_non_matches_per_match=num_non_matches_per_match, img_b_mask=image_b_mask)
+        uv_b_non_matches = correspondence_finder.create_non_correspondences(uv_b, num_non_matches_per_match=self.num_non_matches_per_match, img_b_mask=image_b_mask)

        if self.debug:
            # only want to bring in plotting code if in debug mode
@@ -132,8 +149,8 @@ def __getitem__(self, index):
            # Just show all images
            # self.debug_show_data(image_a_rgb, image_a_depth, image_b_pose,
            #                      image_b_rgb, image_b_depth, image_b_pose)
-            uv_a_long = (torch.t(uv_a[0].repeat(num_non_matches_per_match, 1)).contiguous().view(-1, 1),
-                         torch.t(uv_a[1].repeat(num_non_matches_per_match, 1)).contiguous().view(-1, 1))
+            uv_a_long = (torch.t(uv_a[0].repeat(self.num_non_matches_per_match, 1)).contiguous().view(-1, 1),
+                         torch.t(uv_a[1].repeat(self.num_non_matches_per_match, 1)).contiguous().view(-1, 1))
            uv_b_non_matches_long = (uv_b_non_matches[0].view(-1, 1), uv_b_non_matches[1].view(-1, 1))

            # Show correspondences
@@ -144,23 +161,24 @@ def __getitem__(self, index):
                                                                     use_previous_plot=(fig, axes),
                                                                     circ_color='r')

+        image_a_rgb, image_b_rgb = self.both_to_tensor([image_a_rgb, image_b_rgb])

-        if self.tensor_transform is not None:
-            image_a_rgb, image_b_rgb = self.both_to_tensor([image_a_rgb, image_b_rgb])
-
-        uv_a_long = (torch.t(uv_a[0].repeat(num_non_matches_per_match, 1)).contiguous().view(-1, 1),
-                     torch.t(uv_a[1].repeat(num_non_matches_per_match, 1)).contiguous().view(-1, 1))
+        uv_a_long = (torch.t(uv_a[0].repeat(self.num_non_matches_per_match, 1)).contiguous().view(-1, 1),
+                     torch.t(uv_a[1].repeat(self.num_non_matches_per_match, 1)).contiguous().view(-1, 1))
        uv_b_non_matches_long = (uv_b_non_matches[0].view(-1, 1), uv_b_non_matches[1].view(-1, 1))

        # flatten correspondences and non_correspondences
-        uv_a = uv_a[1].type(dtype_long) * 640 + uv_a[0].type(dtype_long)
-        uv_b = uv_b[1].type(dtype_long) * 640 + uv_b[0].type(dtype_long)
-        uv_a_long = uv_a_long[1].type(dtype_long) * 640 + uv_a_long[0].type(dtype_long)
-        uv_b_non_matches_long = uv_b_non_matches_long[1].type(dtype_long) * 640 + uv_b_non_matches_long[0].type(dtype_long)
-        uv_a_long = uv_a_long.squeeze(1)
-        uv_b_non_matches_long = uv_b_non_matches_long.squeeze(1)
+        matches_a = uv_a[1].long() * 640 + uv_a[0].long()
+        matches_b = uv_b[1].long() * 640 + uv_b[0].long()
+        non_matches_a = uv_a_long[1].long() * 640 + uv_a_long[0].long()
+        non_matches_a = non_matches_a.squeeze(1)
+        non_matches_b = uv_b_non_matches_long[1].long() * 640 + uv_b_non_matches_long[0].long()
+        non_matches_b = non_matches_b.squeeze(1)

-        return "matches", image_a_rgb, image_b_rgb, uv_a, uv_b, uv_a_long, uv_b_non_matches_long
+        return "matches", image_a_rgb, image_b_rgb, matches_a, matches_b, non_matches_a, non_matches_b
+
+    def return_empty_data(self, image_a_rgb, image_b_rgb):
+        return None, image_a_rgb, image_b_rgb, torch.zeros(1).long(), torch.zeros(1).long(), torch.zeros(1).long(), torch.zeros(1).long()
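
The flattening step above turns each (u, v) pixel location into a single linear index, row * 640 + column, where 640 is the image width assumed by this code; the resulting LongTensors can then index a flattened per-pixel tensor. How downstream code consumes these indices is not shown in this diff, so the gather below is only an illustrative sketch with invented shapes.

    import torch

    image_width, image_height, descriptor_dim = 640, 480, 3        # assumed sizes
    descriptor_image = torch.randn(image_height, image_width, descriptor_dim)
    flat_descriptors = descriptor_image.view(-1, descriptor_dim)    # shape (H*W, D)

    # e.g. pixels (u=10, v=5) and (u=3, v=7) as linear indices v * 640 + u
    matches_a = torch.LongTensor([5 * image_width + 10, 7 * image_width + 3])
    matched_descriptors = torch.index_select(flat_descriptors, 0, matches_a)

    # recovering (u, v) from a linear index
    u = matches_a % image_width
    v = matches_a // image_width
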

    def get_rgbd_mask_pose(self, scene_name, img_idx):
        """
@@ -238,6 +256,16 @@ def get_depth_image(self, depth_filename):
        """
        return Image.open(depth_filename)

+    def get_depth_image_from_scene_name_and_idx(self, scene_name, img_idx):
+        """
+        Returns a depth image given a scene_name and image index
+        :param scene_name:
+        :param img_idx: str or int
+        :return: PIL.Image.Image
+        """
+        img_filename = self.get_image_filename(scene_name, img_idx, ImageType.DEPTH)
+        return self.get_depth_image(img_filename)
+
    def get_mask_image(self, mask_filename):
        """
        :param mask_filename: string of full path to mask image
@@ -255,43 +283,6 @@ def get_mask_image_from_scene_name_and_idx(self, scene_name, img_idx):
        img_filename = self.get_image_filename(scene_name, img_idx, ImageType.MASK)
        return self.get_mask_image(img_filename)

-
-    def different_enough(self, pose_1, pose_2):
-        translation_1 = np.asarray(pose_1[0, 3], pose_1[1, 3], pose_1[2, 3])
-        translation_2 = np.asarray(pose_2[0, 3], pose_2[1, 3], pose_2[2, 3])
-
-        translation_threshold = 0.2  # meters
-        if np.linalg.norm(translation_1 - translation_2) > translation_threshold:
-            return True
-
-        # later implement something that is different_enough for rotations?
-        return False
-
-    def get_random_rgb_image_filename(self, scene_directory):
-        rgb_images_regex = os.path.join(scene_directory, "images/*_rgb.png")
-        all_rgb_images_in_scene = sorted(glob.glob(rgb_images_regex))
-        random_rgb_image = random.choice(all_rgb_images_in_scene)
-        return random_rgb_image
-
-    def get_specific_rgb_image_filname(self, scene_name, img_index):
-        DeprecationWarning("use get_specific_rgb_image_filename instead")
-        return self.get_specific_rgb_image_filename(scene_name, img_index)
-
-    def get_specific_rgb_image_filename(self, scene_name, img_index):
-        """
-        Returns the filename for the specific RGB image
-        :param scene_name:
-        :param img_index: int or str
-        :return:
-        """
-        if isinstance(img_index, int):
-            img_index = utils.getPaddedString(img_index)
-
-        scene_directory = self.get_full_path_for_scene(scene_name)
-        images_dir = os.path.join(scene_directory, "images")
-        rgb_image_filename = os.path.join(images_dir, img_index + "_rgb.png")
-        return rgb_image_filename
-
    def get_image_filename(self, scene_name, img_index, image_type):
        raise NotImplementedError("Implement in superclass")

@@ -313,11 +304,6 @@ def get_pose_from_scene_name_and_idx(self, scene_name, idx):
        """
        raise NotImplementedError("subclass must implement this method")

-    def get_depth_filename(self, rgb_image):
-        prefix = rgb_image.split("rgb")[0]
-        depth_filename = prefix + "depth.png"
-        return depth_filename
-
    # this function cowbody copied from:
    # https://www.lfd.uci.edu/~gohlke/code/transformations.py.html
    def quaternion_matrix(self, quaternion):
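
The body of quaternion_matrix is not included in this hunk; the comment notes it is copied from transformations.py. For reference, a minimal sketch of the standard conversion, assuming the (w, x, y, z) quaternion ordering that transformations.py uses and returning a 4x4 homogeneous matrix; this is not the verbatim library code.

    import numpy as np

    def quaternion_matrix_sketch(quaternion):
        # normalize, then build the standard rotation matrix for q = (w, x, y, z)
        w, x, y, z = np.asarray(quaternion, dtype=np.float64)
        norm_sq = w * w + x * x + y * y + z * z
        if norm_sq < np.finfo(np.float64).eps:
            return np.identity(4)
        w, x, y, z = np.array([w, x, y, z]) / np.sqrt(norm_sq)
        matrix = np.identity(4)
        matrix[:3, :3] = [
            [1 - 2 * (y * y + z * z), 2 * (x * y - z * w),     2 * (x * z + y * w)],
            [2 * (x * y + z * w),     1 - 2 * (x * x + z * z), 2 * (y * z - x * w)],
            [2 * (x * z - y * w),     2 * (y * z + x * w),     1 - 2 * (x * x + y * y)],
        ]
        return matrix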