
Commit e346e21

Author: csguestp
Commit message: semantic counts

1 parent a1b92c7

9 files changed: +62 -1015 lines

Diff for: datasets/scannet_dataset.py (+5 -29)

@@ -29,13 +29,14 @@ def g(x_):
 
 ## ScanNet dataset class
 class ScanNetDataset(Dataset):
-    def __init__(self, options, split, random=True):
+    def __init__(self, options, split, load_confidence=False, random=True):
         self.options = options
         self.split = split
         self.random = random
         self.imagePaths = []
-        self.dataFolder = '/gruvi/Data/chenliu/ScanNet/scans/'
-
+        self.dataFolder = options.dataFolder
+        self.load_confidence = load_confidence
+
         with open('split_' + split + '.txt', 'r') as f:
             for line in f:
                 scene_id = line.strip()
@@ -134,32 +135,7 @@ def __getitem__(self, index):
             colors[:, :3] = colors[:, :3] + np.random.randn(3) * 0.1
             pass
 
-        if self.options.trainingMode == 'semantic':
-            unique_instances, indices, instances = np.unique(instances, return_index=True, return_inverse=True)
-            labels = labels[indices]
-            labels[labels == -100] = 20
-            new_coords = np.zeros(coords.shape, dtype=coords.dtype)
-            for instance in range(len(unique_instances)):
-                instance_mask = instances == instance
-                instance_coords = coords[instance_mask]
-                mins = instance_coords.min(0)
-                maxs = instance_coords.max(0)
-                max_range = (maxs - mins).max()
-                if self.split == 'train':
-                    padding = (maxs - mins) * np.random.random(3) * 0.1
-                else:
-                    padding = max_range * 0.05
-                    pass
-                max_range += padding * 2
-                mins = (mins + maxs) / 2 - max_range / 2
-                instance_coords = np.clip(np.round((instance_coords - mins) / max_range * full_scale), 0, full_scale - 1)
-                new_coords[instance_mask] = instance_coords
-                continue
-            coords = np.concatenate([new_coords, np.expand_dims(instances, -1)], axis=-1)
-            sample = [coords.astype(np.int64), colors.astype(np.float32), faces.astype(np.int64), labels.astype(np.int64), instances.astype(np.int64), self.imagePaths[index]]
-            return sample
-
-        if self.options.trainingMode == 'confidence':
+        if self.load_confidence:
             scene_id = self.imagePaths[index].split('/')[-1].split('_vh_clean_2')[0]
             info = torch.load('test/output_normal_augment_2_' + self.split + '/cache/' + scene_id + '.pth')
             if len(info) == 2:
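
For orientation, a minimal sketch of how the refactored dataset might be constructed after this change. parse_args and the dataFolder option come from options.py in this same commit; the choice of the 'train' split and indexing a single sample are illustrative assumptions, not part of the diff.

from options import parse_args
from datasets.scannet_dataset import ScanNetDataset

options = parse_args()  # supplies options.dataFolder (new in this commit)
# load_confidence=True would additionally load cached predictions from
# test/output_normal_augment_2_<split>/cache/<scene_id>.pth, per the diff.
dataset = ScanNetDataset(options, 'train', load_confidence=False)
sample = dataset[0]  # one preprocessed scene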

Diff for: datasets/semantic_counts_pixelwise.npy (456 Bytes)

Binary file not shown.
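
The new binary is small enough to inspect directly. A hedged peek, assuming it is a standard .npy array of per-class counts matching the commit message; the shape and meaning are guesses, only the path is from the diff.

import numpy as np

counts = np.load('datasets/semantic_counts_pixelwise.npy')
print(counts.shape, counts.dtype, counts.sum())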

Diff for: models/instance.py (-298)

Large diffs are not rendered by default.

Diff for: options.py (+9 -36)

@@ -35,34 +35,10 @@ def parse_args():
                         default=2.5e-4, type=float)
     parser.add_argument('--numEpochs', dest='numEpochs',
                         help='the number of epochs',
-                        default=1000, type=int)
+                        default=50, type=int)
     parser.add_argument('--startEpoch', dest='startEpoch',
                         help='starting epoch index',
                         default=-1, type=int)
-    parser.add_argument('--modelType', dest='modelType',
-                        help='model type',
-                        default='', type=str)
-    parser.add_argument('--heatmapThreshold', dest='heatmapThreshold',
-                        help='heatmap threshold for positive predictions',
-                        default=0.5, type=float)
-    parser.add_argument('--distanceThreshold3D', dest='distanceThreshold3D',
-                        help='distance threshold 3D',
-                        default=0.2, type=float)
-    parser.add_argument('--distanceThreshold2D', dest='distanceThreshold2D',
-                        help='distance threshold 2D',
-                        default=20, type=float)
-    parser.add_argument('--numInputPlanes', dest='numInputPlanes',
-                        help='the number of input planes',
-                        default=1024, type=int)
-    parser.add_argument('--numOutputPlanes', dest='numOutputPlanes',
-                        help='the number of output planes',
-                        default=10, type=int)
-    parser.add_argument('--numInputClasses', dest='numInputClasses',
-                        help='the number of input classes',
-                        default=0, type=int)
-    parser.add_argument('--numOutputClasses', dest='numOutputClasses',
-                        help='the number of output classes',
-                        default=0, type=int)
     parser.add_argument('--width', dest='width',
                         help='input width',
                         default=256, type=int)
@@ -77,32 +53,29 @@ def parse_args():
                         default=50, type=int)
     parser.add_argument('--numScales', dest='numScales',
                         help='the number of scales',
-                        default=1, type=int)
+                        default=2, type=int)
     parser.add_argument('--numCrossScales', dest='numCrossScales',
                         help='the number of cross scales',
                         default=0, type=int)
     parser.add_argument('--numNeighbors', dest='numNeighbors',
                         help='the number of neighbors',
                         default=6, type=int)
-    parser.add_argument('--outputScale', dest='outputScale',
-                        help='output scale',
-                        default=256, type=int)
-    parser.add_argument('--negativeWeights', dest='negativeWeights',
-                        help='negative weights',
-                        default='531111', type=str)
-    parser.add_argument('--trainingMode', dest='trainingMode',
-                        help='training mode',
-                        default='all', type=str)
     ## Flags
     parser.add_argument('--visualizeMode', dest='visualizeMode',
                         help='visualization mode',
                         default='', type=str)
     parser.add_argument('--suffix', dest='suffix',
                         help='suffix to distinguish experiments',
-                        default='', type=str)
+                        default='normal_augment', type=str)
     parser.add_argument('--useCache', dest='useCache',
                         help='use cache instead of re-computing existing examples',
                         default=0, type=int)
+    parser.add_argument('--dataFolder', dest='dataFolder',
+                        help='data folder',
+                        default='/gruvi/Data/chenliu/ScanNet/scans/', type=str)
+    parser.add_argument('--labelFile', dest='labelFile',
+                        help='path to scannetv2-labels.combined.tsv',
+                        default='/gruvi/Data/chenliu/ScanNet/tasks/scannetv2-labels.combined.tsv', type=str)
 
     args = parser.parse_args()
     return args
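
A hedged usage sketch for the two new options; the entry-point script name in the comment is an assumption, while the flag names and defaults are taken from the diff above.

# e.g. python train.py --dataFolder /path/to/ScanNet/scans/ \
#          --labelFile /path/to/scannetv2-labels.combined.tsv
from options import parse_args

options = parse_args()
print(options.dataFolder)  # default: /gruvi/Data/chenliu/ScanNet/scans/
print(options.labelFile)   # default: the scannetv2-labels.combined.tsv path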

Diff for: scripts/prepare_data.py (+20 -82)

@@ -7,7 +7,7 @@
 
 # python imports
 import math
-import os, sys, argparse
+import os, sys
 import inspect
 import json
 import glob
@@ -18,29 +18,15 @@
     print("Failed to import numpy package.")
     sys.exit(-1)
 
-currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
-parentdir = os.path.dirname(currentdir)
-sys.path.insert(0,parentdir)
-import util
-import util_3d
+# currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
+# parentdir = os.path.dirname(currentdir)
+# sys.path.insert(0,parentdir)
+from scripts.util import read_label_mapping
+from scripts.util_3d import read_mesh_vertices
 import torch
-import multiprocessing as mp
-from utils import write_ply_label
-
-TASK_TYPES = {'label', 'instance'}
-
-parser = argparse.ArgumentParser()
-#parser.add_argument('--scan_path', required=True, help='path to scannet scene (e.g., data/ScanNet/v2/scene0000_00')
-parser.add_argument('--label_map_file', default='/gruvi/Data/chenliu/ScanNet/tasks/scannetv2-labels.combined.tsv', help='path to scannetv2-labels.combined.tsv')
-parser.add_argument('--type', default='instance', help='task type [label or instance]')
-opt = parser.parse_args()
-assert opt.type in TASK_TYPES
-
-label_map = util.read_label_mapping(opt.label_map_file, label_from='raw_category', label_to='nyu40id')
-# remapper=np.ones(150)*(-100)
-# for i,x in enumerate([1,2,3,4,5,6,7,8,9,10,11,12,14,16,24,28,33,34,36,39]):
-#     remapper[x]=i
-
+import multiprocessing as mp
+import functools
+#from utils import write_ply_label
 
 def read_aggregation(filename):
     assert os.path.isfile(filename)
@@ -76,44 +62,23 @@ def read_segmentation(filename):
     return seg_to_verts, num_verts
 
 
-def export(filename):
+def export(filename, label_map):
     scan_name = filename.split('_vh')[0]
     mesh_file = os.path.join(scan_name + '_vh_clean_2.ply')
     agg_file = os.path.join(scan_name + '.aggregation.json')
     seg_file = os.path.join(scan_name + '_vh_clean_2.0.010000.segs.json')
 
-    if os.path.exists(mesh_file[:-4] + '.pth') and len(torch.load(mesh_file[:-4] + '.pth')) == 5 and False:
+    print(filename)
+    if os.path.exists(mesh_file[:-4] + '.pth'):
         return
-    print(filename)
 
-    #mesh_vertices, mesh_colors, faces = util_3d.read_mesh_vertices(mesh_file)
+    mesh_vertices, mesh_colors, faces = read_mesh_vertices(mesh_file)
     if os.path.exists(agg_file):
         object_id_to_segs, label_to_segs = read_aggregation(agg_file)
         seg_to_verts, num_verts = read_segmentation(seg_file)
         label_ids = np.zeros(shape=(num_verts), dtype=np.uint32) # 0: unannotated
         instance_ids = np.zeros(shape=(num_verts), dtype=np.uint32) # 0: unannotated
 
-        invalid_instance_ids = np.zeros(shape=(num_verts), dtype=np.uint32) # 0: unannotated
-        # print(len(seg_to_verts))
-        for object_id, segs in object_id_to_segs.items():
-            object_verts = []
-            for seg in segs:
-                verts = seg_to_verts[seg]
-                object_verts.append(verts)
-                continue
-            nums = np.array([len(verts) for verts in object_verts])
-            invalid_indices = np.logical_and(nums < (0.5 * nums.sum()), nums >= 100)
-            invalid_indices = invalid_indices.nonzero()[0]
-            if len(invalid_indices) == 0:
-                continue
-            seg = segs[np.random.choice(invalid_indices)]
-            verts = seg_to_verts[seg]
-
-            invalid_instance_ids[verts] = object_id
-            continue
-        torch.save((invalid_instance_ids, ), mesh_file[:-4] + '_invalid.pth')
-        return
-
         # write_ply_label('test/mesh.ply', mesh_vertices, faces, label_ids)
         # exit(1)
 
@@ -132,46 +97,19 @@ def export(filename):
         label_ids = np.zeros(shape=(num_verts), dtype=np.uint32) # 0: unannotated
         instance_ids = np.zeros(shape=(num_verts), dtype=np.uint32) # 0: unannotated
         pass
-    #point_cloud = torch.load(mesh_file[:-4] + '.pth')
-    #print(point_cloud)
-    #print([(v.shape, v.min(0), v.max(0), v.dtype) for v in point_cloud])
-    #print(np.abs(remapper[label_ids] - point_cloud[2]).max())
-    #print([(v.shape, v.min(), v.max()) for v in [mesh_vertices, mesh_colors, label_ids, instance_ids]])
-    #exit(1)
     mesh_vertices = np.ascontiguousarray(mesh_vertices - mesh_vertices.mean(0))
     mesh_colors = np.ascontiguousarray(mesh_colors) / 127.5 - 1
-    # print(np.abs(mesh_vertices - point_cloud[0]).max())
-    # print(np.abs(mesh_colors - point_cloud[1]).max())
-    # print(np.abs(remapper[label_ids] - point_cloud[2]).max())
-    # exit(1)
     torch.save((mesh_vertices, mesh_colors, label_ids, instance_ids, faces), mesh_file[:-4] + '.pth')
     return
 
-def main():
-    ROOT_FOLDER = '/gruvi/Data/chenliu/ScanNet/scans/'
-    files = sorted(glob.glob(ROOT_FOLDER + '*/*_vh_clean_2.ply'))
-    #print(files)
-    #exit(1)
+def prepare_data(options):
+    ROOT_FOLDER = options.dataFolder
+    files = sorted(glob.glob(options.dataFolder + '*/*_vh_clean_2.ply'))
+    p = mp.Pool(processes=mp.cpu_count())
 
-    # files = [filename for filename in files if 'scene0568_00' in filename]
-    # print(files)
-    # export(files[0])
-    # exit(1)
+    label_map = read_label_mapping(options.labelFile, label_from='raw_category', label_to='nyu40id')
 
-
-    #print(mp.cpu_count())
-    # for filename in files:
-    #     export(filename)
-    #     continue
-    # exit(1)
-    p = mp.Pool(processes=mp.cpu_count())
-    p.map(export, files)
+    p.map(functools.partial(export, label_map=label_map), files)
     p.close()
     p.join()
-
-
-#for filename in files:
-    #export(mesh_file, agg_file, seg_file, opt.label_map_file)
-
-
-if __name__ == '__main__':
-    main()
+    return
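
The functools.partial change is the key mechanic here: Pool.map hands each worker exactly one argument, so the new label_map parameter of export has to be bound before the pool fans out. A self-contained sketch of the same pattern with toy data (the filenames and dict are placeholders, not the real label map):

import functools
import multiprocessing as mp

def export(filename, label_map):
    # Stand-in worker: look up the label bound in from the parent process.
    return filename, label_map.get(filename, 'unknown')

if __name__ == '__main__':
    label_map = {'scene0000_00.ply': 'chair'}  # stand-in for read_label_mapping(...)
    with mp.Pool(processes=2) as pool:
        results = pool.map(functools.partial(export, label_map=label_map),
                           ['scene0000_00.ply', 'scene0001_00.ply'])
    print(results)  # [('scene0000_00.ply', 'chair'), ('scene0001_00.ply', 'unknown')]

Note that a partial of a top-level function pickles cleanly, which is what lets it cross the process boundary; a lambda or local closure would not.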

Diff for: scripts/util_3d.py (+1 -1)

@@ -14,7 +14,7 @@
     print("pip install plyfile")
     sys.exit(-1)
 
-import util
+import scripts.util
 
 
 # matrix: 4x4 np array
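
This one-line change, together with the scripts.util imports in prepare_data.py above, suggests the scripts are now imported as a package from the repository root rather than run standalone. A hedged driver under that assumption (it presumes a scripts/__init__.py exists and the working directory is the repo root):

from options import parse_args
from scripts.prepare_data import prepare_data

if __name__ == '__main__':
    prepare_data(parse_args())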
