@@ -37,28 +37,53 @@ class ImageType:
3737
3838class  DenseCorrespondenceDataset (data .Dataset ):
3939
40-     def  __init__ (self , debug = False ):
40+     def  __init__ (self , debug = False ,  training_config = None ):
4141
4242        self .debug  =  debug 
43- 
4443        self .mode  =  "train" 
45-         
4644        self .both_to_tensor  =  ComposeJoint (
4745            [
4846                [transforms .ToTensor (), transforms .ToTensor ()]
4947            ])
48+         if  training_config == None :
49+             self .num_matching_attempts      =  50000 
50+             self .num_non_matches_per_match  =  150 
51+         else :
52+             self .num_matching_attempts  =  training_config ["num_matching_attempts" ]
53+             self .num_non_matches_per_match  =  training_config ["num_non_matches_per_match" ]
5054
51-         self .tensor_transform  =  ComposeJoint (
52-                 [
53-                     [transforms .ToTensor (), None ],
54-                     [None , transforms .Lambda (lambda  x : torch .from_numpy (x ).long ()) ]
55-                 ])
55+         if  self .debug :
56+             self .num_attempts  =  20 
57+             self .num_non_matches_per_match  =  1 
5658
5759    def  __len__ (self ):
5860        return  self .num_images_total 
5961
6062    def  __getitem__ (self , index ):
61-         dtype_long  =  torch .LongTensor 
63+         """ 
64+         The method through which the dataset is accessed for training. 
65+ 
66+         The index param is not currently used, and instead each dataset[i] is the result of 
67+         a random sampling over: 
68+         - random scene 
69+         - random rgbd frame from that scene 
70+         - random rgbd frame (different enough pose) from that scene 
71+         - various randomization in the match generation and non-match generation procedure 
72+ 
73+         returns a large amount of variables, separated by commas. 
74+ 
75+         0th return arg: the type of data sampled (this can be used as a flag for different loss functions) 
76+         0th rtype: string 
77+ 
78+         1st, 2nd return args: image_a_rgb, image_b_rgb 
79+         1st, 2nd rtype: 3-dimensional torch.FloatTensor of shape (image_height, image_width, 3) 
80+ 
81+         3rd, 4th return args: matches_a, matches_b 
82+         3rd, 4th rtype: 1-dimensional torch.LongTensor of shape (num_matches) 
83+ 
84+         5th, 6th return args: non_matches_a, non_matches_b 
85+         5th, 6th rtype: 1-dimensional torch.LongTensor of shape (num_non_matches)         
86+         """ 
6287
6388        # pick a scene 
6489        scene_name  =  self .get_random_scene_name ()
@@ -72,31 +97,22 @@ def __getitem__(self, index):
7297
7398        if  image_b_idx  is  None :
7499            logging .info ("no frame with sufficiently different pose found, returning" )
75-             print  "no frame with sufficiently different pose found, returning" 
76-             return  "matches" , image_a_rgb , image_a_rgb , torch .zeros (1 ).type (dtype_long ), torch .zeros (1 ).type (
77-                 dtype_long ), torch .zeros (1 ).type (dtype_long ), torch .zeros (1 ).type (dtype_long )
78- 
100+             # TODO: return something cleaner than no-data 
101+             return  self .return_empty_data (image_a_rgb , image_b_rgb )
79102
80103        image_b_rgb , image_b_depth , image_b_mask , image_b_pose  =  self .get_rgbd_mask_pose (scene_name , image_b_idx )
81104
82- 
83-         num_attempts  =  50000 
84-         num_non_matches_per_match  =  150 
85-         if  self .debug :
86-             num_attempts  =  20 
87-             num_non_matches_per_match  =  1 
88- 
89105        image_a_depth_numpy  =  np .asarray (image_a_depth )
90106        image_b_depth_numpy  =  np .asarray (image_b_depth )
91107
92108        # find correspondences 
93109        uv_a , uv_b  =  correspondence_finder .batch_find_pixel_correspondences (image_a_depth_numpy , image_a_pose , 
94110                                                                           image_b_depth_numpy , image_b_pose , 
95-                                                                            num_attempts = num_attempts , img_a_mask = np .asarray (image_a_mask ))
111+                                                                            num_attempts = self . num_matching_attempts , img_a_mask = np .asarray (image_a_mask ))
96112
97113        if  uv_a  is  None :
98-             print   "No  matches this time" 
99-             return  "matches" ,  image_a_rgb ,  image_b_rgb ,  torch . zeros ( 1 ). type ( dtype_long ),  torch . zeros ( 1 ). type ( dtype_long ),  torch . zeros ( 1 ). type ( dtype_long ),  torch . zeros ( 1 ). type ( dtype_long )
114+             logging . info ( "no  matches found, returning" ) 
115+             return  self . return_empty_data ( image_a_rgb ,  image_b_rgb )
100116
101117        if  self .debug :
102118            # downsample so can plot 
@@ -106,24 +122,25 @@ def __getitem__(self, index):
106122            uv_b  =  (torch .index_select (uv_b [0 ], 0 , indexes_to_keep ), torch .index_select (uv_b [1 ], 0 , indexes_to_keep ))
107123
108124        # data augmentation 
109-         if  not  self .debug :
110-             [image_a_rgb ], uv_a                  =  correspondence_augmentation .random_image_and_indices_mutation ([image_a_rgb ], uv_a )
111-             [image_b_rgb , image_b_mask ], uv_b    =  correspondence_augmentation .random_image_and_indices_mutation ([image_b_rgb , image_b_mask ], uv_b )
112-         else : # also mutate depth just for plotting 
113-             [image_a_rgb , image_a_depth ], uv_a                =  correspondence_augmentation .random_image_and_indices_mutation ([image_a_rgb , image_a_depth ], uv_a )
114-             [image_b_rgb , image_b_depth , image_b_mask ], uv_b  =  correspondence_augmentation .random_image_and_indices_mutation ([image_b_rgb , image_b_depth , image_b_mask ], uv_b )
115-             image_a_depth_numpy  =  np .asarray (image_a_depth )
116-             image_b_depth_numpy  =  np .asarray (image_b_depth )
125+         #  if not self.debug:
126+         #      [image_a_rgb], uv_a                 = correspondence_augmentation.random_image_and_indices_mutation([image_a_rgb], uv_a)
127+         #      [image_b_rgb, image_b_mask], uv_b   = correspondence_augmentation.random_image_and_indices_mutation([image_b_rgb, image_b_mask], uv_b)
128+         #  else: # also mutate depth just for plotting
129+         #      [image_a_rgb, image_a_depth], uv_a               = correspondence_augmentation.random_image_and_indices_mutation([image_a_rgb, image_a_depth], uv_a)
130+         #      [image_b_rgb, image_b_depth, image_b_mask], uv_b = correspondence_augmentation.random_image_and_indices_mutation([image_b_rgb, image_b_depth, image_b_mask], uv_b)
131+         #      image_a_depth_numpy = np.asarray(image_a_depth)
132+         #      image_b_depth_numpy = np.asarray(image_b_depth)
117133
118134        # find non_correspondences 
135+ 
119136        if  index % 2 :
120137            print  "masking non-matches" 
121138            image_b_mask  =  torch .from_numpy (np .asarray (image_b_mask )).type (torch .FloatTensor )
122139        else :
123140            print  "not masking non-matches" 
124141            image_b_mask  =  None 
125142
126-         uv_b_non_matches  =  correspondence_finder .create_non_correspondences (uv_b , num_non_matches_per_match = num_non_matches_per_match , img_b_mask = image_b_mask )
143+         uv_b_non_matches  =  correspondence_finder .create_non_correspondences (uv_b , num_non_matches_per_match = self . num_non_matches_per_match , img_b_mask = image_b_mask )
127144
128145        if  self .debug :
129146            # only want to bring in plotting code if in debug mode 
@@ -132,8 +149,8 @@ def __getitem__(self, index):
132149            # Just show all images  
133150            # self.debug_show_data(image_a_rgb, image_a_depth, image_b_pose, 
134151            #                  image_b_rgb, image_b_depth, image_b_pose) 
135-             uv_a_long  =  (torch .t (uv_a [0 ].repeat (num_non_matches_per_match , 1 )).contiguous ().view (- 1 ,1 ), 
136-                      torch .t (uv_a [1 ].repeat (num_non_matches_per_match , 1 )).contiguous ().view (- 1 ,1 ))
152+             uv_a_long  =  (torch .t (uv_a [0 ].repeat (self . num_non_matches_per_match , 1 )).contiguous ().view (- 1 ,1 ), 
153+                      torch .t (uv_a [1 ].repeat (self . num_non_matches_per_match , 1 )).contiguous ().view (- 1 ,1 ))
137154            uv_b_non_matches_long  =  (uv_b_non_matches [0 ].view (- 1 ,1 ), uv_b_non_matches [1 ].view (- 1 ,1 ) )
138155
139156            # Show correspondences 
@@ -144,23 +161,24 @@ def __getitem__(self, index):
144161                                                  use_previous_plot = (fig ,axes ),
145162                                                  circ_color = 'r' )
146163
164+         image_a_rgb , image_b_rgb  =  self .both_to_tensor ([image_a_rgb , image_b_rgb ])
147165
148-         if  self .tensor_transform  is  not   None :
149-             image_a_rgb , image_b_rgb  =  self .both_to_tensor ([image_a_rgb , image_b_rgb ])
150- 
151-         uv_a_long  =  (torch .t (uv_a [0 ].repeat (num_non_matches_per_match , 1 )).contiguous ().view (- 1 ,1 ), 
152-                      torch .t (uv_a [1 ].repeat (num_non_matches_per_match , 1 )).contiguous ().view (- 1 ,1 ))
166+         uv_a_long  =  (torch .t (uv_a [0 ].repeat (self .num_non_matches_per_match , 1 )).contiguous ().view (- 1 ,1 ), 
167+                      torch .t (uv_a [1 ].repeat (self .num_non_matches_per_match , 1 )).contiguous ().view (- 1 ,1 ))
153168        uv_b_non_matches_long  =  (uv_b_non_matches [0 ].view (- 1 ,1 ), uv_b_non_matches [1 ].view (- 1 ,1 ) )
154169
155170        # flatten correspondences and non_correspondences 
156-         uv_a  =  uv_a [1 ].type ( dtype_long )* 640 + uv_a [0 ].type ( dtype_long )
157-         uv_b  =  uv_b [1 ].type ( dtype_long )* 640 + uv_b [0 ].type ( dtype_long )
158-         uv_a_long  =  uv_a_long [1 ].type ( dtype_long )* 640 + uv_a_long [0 ].type ( dtype_long )
159-         uv_b_non_matches_long  =  uv_b_non_matches_long [ 1 ]. type ( dtype_long ) * 640 + uv_b_non_matches_long [ 0 ]. type ( dtype_long )
160-         uv_a_long  =  uv_a_long . squeeze ( 1 )
161-         uv_b_non_matches_long  =  uv_b_non_matches_long .squeeze (1 )
171+         matches_a  =  uv_a [1 ].long ( )* 640 + uv_a [0 ].long ( )
172+         matches_b  =  uv_b [1 ].long ( )* 640 + uv_b [0 ].long ( )
173+         non_matches_a  =  uv_a_long [1 ].long ( )* 640 + uv_a_long [0 ].long ( )
174+         non_matches_a  =  non_matches_a . squeeze ( 1 )
175+         non_matches_b  =  uv_b_non_matches_long [ 1 ]. long () * 640 + uv_b_non_matches_long [ 0 ]. long ( )
176+         non_matches_b  =  non_matches_b .squeeze (1 )
162177
163-         return  "matches" , image_a_rgb , image_b_rgb , uv_a , uv_b , uv_a_long , uv_b_non_matches_long 
178+         return  "matches" , image_a_rgb , image_b_rgb , matches_a , matches_b , non_matches_a , non_matches_b 
179+ 
180+     def  return_empty_data (self , image_a_rgb , image_b_rgb ):
181+         None , image_a_rgb , image_b_rgb , torch .zeros (1 ).long (), torch .zeros (1 ).long (), torch .zeros (1 ).long (), torch .zeros (1 ).long ()
164182
165183    def  get_rgbd_mask_pose (self , scene_name , img_idx ):
166184        """ 
@@ -238,6 +256,16 @@ def get_depth_image(self, depth_filename):
238256        """ 
239257        return  Image .open (depth_filename )
240258
259+     def  get_depth_image_from_scene_name_and_idx (self , scene_name , img_idx ):
260+         """ 
261+         Returns a depth image given a scene_name and image index 
262+         :param scene_name: 
263+         :param img_idx: str or int 
264+         :return: PIL.Image.Image 
265+         """ 
266+         img_filename  =  self .get_image_filename (scene_name , img_idx , ImageType .DEPTH )
267+         return  self .get_depth_image (img_filename )
268+ 
241269    def  get_mask_image (self , mask_filename ):
242270        """ 
243271        :param mask_filename: string of full path to mask image 
@@ -255,43 +283,6 @@ def get_mask_image_from_scene_name_and_idx(self, scene_name, img_idx):
255283        img_filename  =  self .get_image_filename (scene_name , img_idx , ImageType .MASK )
256284        return  self .get_mask_image (img_filename )
257285
258- 
259-     def  different_enough (self , pose_1 , pose_2 ):
260-         translation_1  =  np .asarray (pose_1 [0 ,3 ], pose_1 [1 ,3 ], pose_1 [2 ,3 ])
261-         translation_2  =  np .asarray (pose_2 [0 ,3 ], pose_2 [1 ,3 ], pose_2 [2 ,3 ])
262- 
263-         translation_threshold  =  0.2  # meters 
264-         if  np .linalg .norm (translation_1  -  translation_2 ) >  translation_threshold :
265-             return  True 
266- 
267-         # later implement something that is different_enough for rotations? 
268-         return  False 
269- 
270-     def  get_random_rgb_image_filename (self , scene_directory ):
271-         rgb_images_regex  =  os .path .join (scene_directory , "images/*_rgb.png" )
272-         all_rgb_images_in_scene  =  sorted (glob .glob (rgb_images_regex ))
273-         random_rgb_image  =  random .choice (all_rgb_images_in_scene )
274-         return  random_rgb_image 
275- 
276-     def  get_specific_rgb_image_filname (self , scene_name , img_index ):
277-         DeprecationWarning ("use get_specific_rgb_image_filename instead" )
278-         return  self .get_specific_rgb_image_filename (scene_name , img_index )
279- 
280-     def  get_specific_rgb_image_filename (self , scene_name , img_index ):
281-         """ 
282-         Returns the filename for the specific RGB image 
283-         :param scene_name: 
284-         :param img_index: int or str 
285-         :return: 
286-         """ 
287-         if  isinstance (img_index , int ):
288-             img_index  =  utils .getPaddedString (img_index )
289- 
290-         scene_directory  =  self .get_full_path_for_scene (scene_name )
291-         images_dir  =  os .path .join (scene_directory , "images" )
292-         rgb_image_filename  =  os .path .join (images_dir , img_index  +  "_rgb.png" )
293-         return  rgb_image_filename 
294- 
295286    def  get_image_filename (self , scene_name , img_index , image_type ):
296287        raise  NotImplementedError ("Implement in superclass" )
297288
@@ -313,11 +304,6 @@ def get_pose_from_scene_name_and_idx(self, scene_name, idx):
313304        """ 
314305        raise  NotImplementedError ("subclass must implement this method" )
315306
316-     def  get_depth_filename (self , rgb_image ):
317-         prefix  =  rgb_image .split ("rgb" )[0 ]
318-         depth_filename  =  prefix + "depth.png" 
319-         return  depth_filename 
320- 
321307    # this function cowbody copied from: 
322308    # https://www.lfd.uci.edu/~gohlke/code/transformations.py.html 
323309    def  quaternion_matrix (self , quaternion ):
0 commit comments