# python imports
import math
- import os, sys, argparse
+ import os, sys
import inspect
import json
import glob

    print("Failed to import numpy package.")
    sys.exit(-1)

- currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
- parentdir = os.path.dirname(currentdir)
- sys.path.insert(0, parentdir)
- import util
- import util_3d
+ # currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
+ # parentdir = os.path.dirname(currentdir)
+ # sys.path.insert(0,parentdir)
+ from scripts.util import read_label_mapping
+ from scripts.util_3d import read_mesh_vertices

import torch
- import multiprocessing as mp
- from utils import write_ply_label
-
- TASK_TYPES = {'label', 'instance'}
-
- parser = argparse.ArgumentParser()
- #parser.add_argument('--scan_path', required=True, help='path to scannet scene (e.g., data/ScanNet/v2/scene0000_00')
- parser.add_argument('--label_map_file', default='/gruvi/Data/chenliu/ScanNet/tasks/scannetv2-labels.combined.tsv', help='path to scannetv2-labels.combined.tsv')
- parser.add_argument('--type', default='instance', help='task type [label or instance]')
- opt = parser.parse_args()
- assert opt.type in TASK_TYPES
-
- label_map = util.read_label_mapping(opt.label_map_file, label_from='raw_category', label_to='nyu40id')
- # remapper=np.ones(150)*(-100)
- # for i,x in enumerate([1,2,3,4,5,6,7,8,9,10,11,12,14,16,24,28,33,34,36,39]):
- #     remapper[x]=i
-
+ import multiprocessing as mp
+ import functools
+ #from utils import write_ply_label

def read_aggregation(filename):
    assert os.path.isfile(filename)
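Note on the import change above: instead of mutating sys.path and importing util/util_3d, the script now imports read_label_mapping and read_mesh_vertices directly from the scripts package. For reference, a minimal sketch of what this script assumes read_label_mapping does, given the standard scannetv2-labels.combined.tsv format (the real implementation lives in scripts/util.py and may handle more cases):

import csv

def read_label_mapping_sketch(filename, label_from='raw_category', label_to='nyu40id'):
    # Parse the tab-separated label file and map one column onto another,
    # e.g. raw ScanNet category names onto NYU40 class ids.
    mapping = {}
    with open(filename) as f:
        for row in csv.DictReader(f, delimiter='\t'):
            mapping[row[label_from]] = int(row[label_to])
    return mapping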
@@ -76,44 +62,23 @@ def read_segmentation(filename):
    return seg_to_verts, num_verts


- def export(filename):
+ def export(filename, label_map):
    scan_name = filename.split('_vh')[0]
    mesh_file = os.path.join(scan_name + '_vh_clean_2.ply')
    agg_file = os.path.join(scan_name + '.aggregation.json')
    seg_file = os.path.join(scan_name + '_vh_clean_2.0.010000.segs.json')

-     if os.path.exists(mesh_file[:-4] + '.pth') and len(torch.load(mesh_file[:-4] + '.pth')) == 5 and False:
+     print(filename)
+     if os.path.exists(mesh_file[:-4] + '.pth'):
        return
-     print(filename)

-     # mesh_vertices, mesh_colors, faces = util_3d.read_mesh_vertices(mesh_file)
+     mesh_vertices, mesh_colors, faces = read_mesh_vertices(mesh_file)
    if os.path.exists(agg_file):
        object_id_to_segs, label_to_segs = read_aggregation(agg_file)
        seg_to_verts, num_verts = read_segmentation(seg_file)
        label_ids = np.zeros(shape=(num_verts), dtype=np.uint32)  # 0: unannotated
        instance_ids = np.zeros(shape=(num_verts), dtype=np.uint32)  # 0: unannotated

-         invalid_instance_ids = np.zeros(shape=(num_verts), dtype=np.uint32)  # 0: unannotated
-         # print(len(seg_to_verts))
-         for object_id, segs in object_id_to_segs.items():
-             object_verts = []
-             for seg in segs:
-                 verts = seg_to_verts[seg]
-                 object_verts.append(verts)
-                 continue
-             nums = np.array([len(verts) for verts in object_verts])
-             invalid_indices = np.logical_and(nums < (0.5 * nums.sum()), nums >= 100)
-             invalid_indices = invalid_indices.nonzero()[0]
-             if len(invalid_indices) == 0:
-                 continue
-             seg = segs[np.random.choice(invalid_indices)]
-             verts = seg_to_verts[seg]
-
-             invalid_instance_ids[verts] = object_id
-             continue
-         torch.save((invalid_instance_ids,), mesh_file[:-4] + '_invalid.pth')
-         return
-
        # write_ply_label('test/mesh.ply', mesh_vertices, faces, label_ids)
        # exit(1)
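For context on the mappings used in this hunk: read_aggregation parses the ScanNet aggregation JSON into object_id_to_segs ({object_id: [seg_id, ...]}) and label_to_segs ({raw_label: [seg_id, ...]}), and read_segmentation yields seg_to_verts ({seg_id: [vertex_index, ...]}). The loop that fills label_ids falls outside the lines shown; a hedged sketch of its usual form in ScanNet export scripts:

def assign_label_ids(label_ids, label_to_segs, seg_to_verts, label_map):
    # Sketch (assumed, not part of this diff): give each vertex the NYU40 id
    # of the raw ScanNet label covering its segment; 0 stays "unannotated".
    for label, segs in label_to_segs.items():
        label_id = label_map.get(label, 0)
        for seg in segs:
            label_ids[seg_to_verts[seg]] = label_id
    return label_ids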
@@ -132,46 +97,19 @@ def export(filename):
        label_ids = np.zeros(shape=(num_verts), dtype=np.uint32)  # 0: unannotated
        instance_ids = np.zeros(shape=(num_verts), dtype=np.uint32)  # 0: unannotated
        pass
-     #point_cloud = torch.load(mesh_file[:-4] + '.pth')
-     #print(point_cloud)
-     #print([(v.shape, v.min(0), v.max(0), v.dtype) for v in point_cloud])
-     #print(np.abs(remapper[label_ids] - point_cloud[2]).max())
-     #print([(v.shape, v.min(), v.max()) for v in [mesh_vertices, mesh_colors, label_ids, instance_ids]])
-     #exit(1)
    mesh_vertices = np.ascontiguousarray(mesh_vertices - mesh_vertices.mean(0))
    mesh_colors = np.ascontiguousarray(mesh_colors) / 127.5 - 1
-     # print(np.abs(mesh_vertices - point_cloud[0]).max())
-     # print(np.abs(mesh_colors - point_cloud[1]).max())
-     # print(np.abs(remapper[label_ids] - point_cloud[2]).max())
-     # exit(1)
    torch.save((mesh_vertices, mesh_colors, label_ids, instance_ids, faces), mesh_file[:-4] + '.pth')
    return

- def main():
-     ROOT_FOLDER = '/gruvi/Data/chenliu/ScanNet/scans/'
-     files = sorted(glob.glob(ROOT_FOLDER + '*/*_vh_clean_2.ply'))
-     #print(files)
-     #exit(1)
+ def prepare_data(options):
+     ROOT_FOLDER = options.dataFolder
+     files = sorted(glob.glob(options.dataFolder + '*/*_vh_clean_2.ply'))
+     p = mp.Pool(processes=mp.cpu_count())

-     # files = [filename for filename in files if 'scene0568_00' in filename]
-     # print(files)
-     # export(files[0])
-     # exit(1)
+     label_map = read_label_mapping(options.labelFile, label_from='raw_category', label_to='nyu40id')

-
-     #print(mp.cpu_count())
-     # for filename in files:
-     #     export(filename)
-     #     continue
-     # exit(1)
-     p = mp.Pool(processes=mp.cpu_count())
-     p.map(export, files)
+     p.map(functools.partial(export, label_map=label_map), files)
    p.close()
    p.join()
-
-     #for filename in files:
-     #     export(mesh_file, agg_file, seg_file, opt.label_map_file)
-
-
- if __name__ == '__main__':
-     main()
+     return
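Two notes on the new prepare_data entry point. First, mp.Pool.map passes exactly one argument to its worker, so functools.partial binds label_map onto export before mapping over the scene files. Second, the argparse and __main__ wiring was removed from this file, so prepare_data now expects the caller to supply an options object. A hedged sketch of that wiring; the attribute names dataFolder and labelFile come from the diff, while the flag names and defaults are illustrative:

import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Hypothetical flags; only the attribute names are dictated by prepare_data().
    parser.add_argument('--dataFolder', default='/path/to/ScanNet/scans/')
    parser.add_argument('--labelFile', default='/path/to/scannetv2-labels.combined.tsv')
    prepare_data(parser.parse_args())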