blur1 = np.ones((1, 3, 1)).astype('float32') / 3
blur2 = np.ones((1, 1, 3)).astype('float32') / 3

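## Note: blur1 and blur2 above appear to be 3-tap averaging kernels along different
## axes, presumably used to smooth the random noise field inside elastic(); the
## convolution step itself is not shown in this excerpt.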
+ ## Elastic augmentation from https://github.com/facebookresearch/SparseConvNet/blob/master/examples/ScanNet/data.py
def elastic(x, gran, mag):
    bb = np.abs(x).max(0).astype(np.int32) // gran + 3
    noise = [np.random.randn(bb[0], bb[1], bb[2]).astype('float32') for _ in range(3)]
@@ -27,8 +28,8 @@ def g(x_):
    noise = g(x)
    return x + g(x) * mag

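A commented-out call later in __getitem__ suggests how elastic() is meant to be used; a minimal sketch, assuming coords is an (N, 3) float array of point coordinates and scale is the voxelization scale used by the loader:

    if self.split == 'train':
        # granularity and magnitude grow with the voxel scale
        coords = elastic(coords, 20 * scale // 50, 160 * scale / 50)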
- ## ScanNet dataset class
class ScanNetDataset(Dataset):
+     """ ScanNet data loader """
    def __init__(self, options, split, load_confidence=False, random=True):
        self.options = options
        self.split = split
@@ -37,19 +38,23 @@ def __init__(self, options, split, load_confidence=False, random=True):
        self.dataFolder = options.dataFolder
        self.load_confidence = load_confidence

-         with open('split_' + split + '.txt', 'r') as f:
+         with open('datasets/split_' + split + '.txt', 'r') as f:
            for line in f:
                scene_id = line.strip()
                if len(scene_id) < 5 or scene_id[:5] != 'scene':
                    continue
                if options.scene_id != '' and options.scene_id not in scene_id:
                    continue
+                 if load_confidence:
+                     confidence_filename = options.test_dir + '/inference/' + split + '/cache/' + scene_id + '.pth'
+                     if not os.path.exists(confidence_filename):
+                         continue
+                     pass
                filename = self.dataFolder + '/' + scene_id + '/' + scene_id + '_vh_clean_2.pth'
                if os.path.exists(filename):
                    info = torch.load(filename)
                    if len(info) == 5:
                        self.imagePaths.append(filename)
-
                        #np.savetxt('semantic_val/' + scene_id + '.txt', info[2], fmt='%d')
                        pass
                    pass
@@ -58,7 +63,8 @@ def __init__(self, options, split, load_confidence=False, random=True):
                    continue
                pass

-         #self.imagePaths = [filename for filename in self.imagePaths if 'scene0217_00' in filename]
+         #self.imagePaths = [filename for filename in self.imagePaths if 'scene0217_00' in filename]
+         print('the number of images', split, len(self.imagePaths))

        if options.numTrainingImages > 0 and split == 'train':
            self.numImages = options.numTrainingImages
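For orientation, a minimal construction sketch; the options object below is hypothetical and only carries the fields visible in this diff (the real argument parser likely defines more), and the path values are placeholders:

    from argparse import Namespace

    options = Namespace(dataFolder='/path/to/scannet', scene_id='',
                        numTrainingImages=0, suffix='normal', test_dir='test')
    train_data = ScanNetDataset(options, split='train', random=True)
    val_data = ScanNetDataset(options, split='val', random=False)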
@@ -89,15 +95,10 @@ def __getitem__(self, index):
            pass

        coords, colors, labels, instances, faces = torch.load(self.imagePaths[index])
-         invalid_instances, = torch.load(self.imagePaths[index].replace('.pth', '_invalid.pth'))
+         # invalid_instances, = torch.load(self.imagePaths[index].replace('.pth', '_invalid.pth'))

        labels = remapper[labels]

-         #neighbor_gt = torch.load(self.imagePaths[index].replace('.pth', '_neighbor.pth'))
-         #print(neighbor_gt[0])
-         #exit(1)
-         #neighbor_gt = 1
-         #print(coords.min(0), coords.max(0))
        if self.split == 'train':
            m = np.eye(3) + np.random.randn(3, 3) * 0.1
            m[0][0] *= np.random.randint(2) * 2 - 1
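            ## m starts as a jittered identity matrix; the line above flips the sign of
            ## the first axis with probability 0.5 (np.random.randint(2) * 2 - 1 is +1 or -1)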
@@ -117,6 +118,7 @@ def __getitem__(self, index):
            #coords = elastic(coords, 20 * scale // 50, 160 * scale / 50)
            pass

+         ## Load normals as input
        if 'normal' in self.options.suffix:
            points_1 = coords[faces[:, 0]]
            points_2 = coords[faces[:, 1]]
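The rest of this branch falls outside the hunk; presumably it gathers the third vertex of each face and derives per-face normals from an edge cross product, roughly along these lines (a sketch, not the file's actual code):

    points_3 = coords[faces[:, 2]]
    face_normals = np.cross(points_2 - points_1, points_3 - points_1)
    face_normals /= np.maximum(np.linalg.norm(face_normals, axis=-1, keepdims=True), 1e-4)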
@@ -134,10 +136,11 @@ def __getitem__(self, index):
        if self.split == 'train':
            colors[:, :3] = colors[:, :3] + np.random.randn(3) * 0.1
            pass
-
+
+         ## Load instance segmentation results to train the confidence prediction network
        if self.load_confidence:
            scene_id = self.imagePaths[index].split('/')[-1].split('_vh_clean_2')[0]
-             info = torch.load('test/output_normal_augment_2_' + self.split + '/cache/' + scene_id + '.pth')
+             info = torch.load(self.options.test_dir + '/inference/' + self.split + '/cache/' + scene_id + '.pth')
            if len(info) == 2:
                semantic_pred, instance_pred = info
            else:
@@ -203,20 +206,9 @@ def __getitem__(self, index):
            pass

        coords = np.round(coords)
-         if False:
-             idxs = (coords.min(1) >= 0) * (coords.max(1) < full_scale)
-             coords = coords[idxs]
-             colors = colors[idxs]
-             labels = labels[idxs]
-             instances = instances[idxs]
-             invalid_instances = invalid_instances[idxs]
-         else:
-             #print(coords.min(0), coords.max(0))
-             #exit(1)
-             coords = np.clip(coords, 0, full_scale - 1)
-             pass
+         coords = np.clip(coords, 0, full_scale - 1)

        coords = np.concatenate([coords, np.full((coords.shape[0], 1), fill_value=index)], axis=-1)
        #coords = np.concatenate([coords, np.expand_dims(instances, -1)], axis=-1)
-         sample = [coords.astype(np.int64), colors.astype(np.float32), faces.astype(np.int64), labels.astype(np.int64), instances.astype(np.int64), invalid_instances.astype(np.int64), self.imagePaths[index]]
+         sample = [coords.astype(np.int64), colors.astype(np.float32), faces.astype(np.int64), labels.astype(np.int64), instances.astype(np.int64), self.imagePaths[index]]
        return sample
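Given a ScanNetDataset instance named dataset, each sample unpacks into five arrays plus the scene path, and the extra coordinate column added above carries the sample index; a sketch with illustrative variable names:

    coords, colors, faces, labels, instances, scene_path = dataset[0]
    xyz = coords[:, :3]          # integer voxel coordinates, clipped to [0, full_scale)
    sample_index = coords[:, 3]  # index column appended in __getitem__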