Skip to content

Commit 4778309

Browse files
committed
add evaluation code
1 parent d65fa96 commit 4778309

File tree

4 files changed

+291
-1
lines changed

4 files changed

+291
-1
lines changed

VOCap.m

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
function ap = VOCap(rec,prec)
2+
3+
mrec=[0 ; rec ; 1];
4+
mpre=[0 ; prec ; 0];
5+
for i=numel(mpre)-1:-1:1
6+
mpre(i)=max(mpre(i),mpre(i+1));
7+
end
8+
i=find(mrec(2:end)~=mrec(1:end-1))+1;
9+
ap=sum((mrec(i)-mrec(i-1)).*mpre(i));
10+

VOCevaldetview_validation.m

Lines changed: 269 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,269 @@
1+
function VOCevaldetview_validation
2+
3+
% path of the results
4+
network = 'vgg16';
5+
region_proposal = 'selective_search';
6+
minoverlap = 0.5;
7+
result_dir = '/var/Projects/SubCNN/fast-rcnn/output/objectnet3d/objectnet3d_val';
8+
method = sprintf('%s_fast_rcnn_view_objectnet3d_%s_iter_160000', network, region_proposal);
9+
10+
poolobj = parpool;
11+
12+
opt = globals();
13+
root = opt.root;
14+
15+
% load class name
16+
classes = textread(sprintf('%s/Image_sets/classes.txt', root), '%s');
17+
num_cls = numel(classes);
18+
19+
% load validation set
20+
gtids = textread(sprintf('%s/Image_sets/val.txt', root), '%s');
21+
M = numel(gtids);
22+
23+
% read ground truth
24+
recs = cell(1, M);
25+
count = 0;
26+
for i = 1:M
27+
% read ground truth
28+
filename = sprintf('%s/Annotations/%s.mat', root, gtids{i});
29+
object = load(filename);
30+
recs{i} = object.record;
31+
count = count + numel(object.record.objects);
32+
end
33+
fprintf('load ground truth done, %d objects\n', count);
34+
35+
recalls_det = cell(num_cls, 1);
36+
precisions_det = cell(num_cls, 1);
37+
aps_det = zeros(num_cls, 1);
38+
39+
recalls_view = cell(num_cls, 1);
40+
precisions_view = cell(num_cls, 1);
41+
aps_view = zeros(num_cls, 1);
42+
similarities_view = cell(num_cls, 1);
43+
accuracies_view = cell(num_cls, 1);
44+
avps_view = zeros(num_cls, 1);
45+
avss_view = zeros(num_cls, 1);
46+
errors_view = cell(num_cls, 1);
47+
parfor k = 1:num_cls
48+
cls = classes{k};
49+
50+
% extract ground truth objects
51+
npos = 0;
52+
npos_view = 0;
53+
gt = [];
54+
for i = 1:M
55+
% extract objects of class
56+
clsinds = strmatch(cls, {recs{i}.objects(:).class}, 'exact');
57+
gt(i).BB = cat(1, recs{i}.objects(clsinds).bbox)';
58+
gt(i).det = false(length(clsinds), 1);
59+
gt(i).ignore = false(length(clsinds), 1);
60+
61+
% viewpoint
62+
num = length(clsinds);
63+
gt(i).view = cell(num, 1);
64+
gt(i).azimuth = zeros(num, 1);
65+
gt(i).elevation = zeros(num, 1);
66+
gt(i).rotation = zeros(num, 1);
67+
for j = 1:num
68+
viewpoint = recs{i}.objects(j).viewpoint;
69+
if isempty(viewpoint) == 1
70+
gt(i).ignore(j) = true;
71+
continue;
72+
end
73+
if isfield(viewpoint, 'azimuth') == 0 || isempty(viewpoint.azimuth) == 1
74+
a = viewpoint.azimuth_coarse;
75+
else
76+
a = viewpoint.azimuth;
77+
end
78+
if isfield(viewpoint, 'elevation') == 0 || isempty(viewpoint.elevation) == 1
79+
e = viewpoint.elevation_coarse;
80+
else
81+
e = viewpoint.elevation;
82+
end
83+
theta = viewpoint.theta;
84+
85+
a = a * pi / 180;
86+
e = e * pi / 180;
87+
theta = theta * pi / 180;
88+
gt(i).view{j} = rotation_matrix(a, e, theta);
89+
gt(i).azimuth(j) = a;
90+
gt(i).elevation(j) = e;
91+
gt(i).rotation(j) = theta;
92+
npos_view = npos_view + 1;
93+
end
94+
95+
npos = npos + length(clsinds);
96+
end
97+
98+
% load detections
99+
filename = sprintf('%s/%s/detections_%s.txt', result_dir, method, cls);
100+
fid = fopen(filename, 'r');
101+
C = textscan(fid, '%s %f %f %f %f %f %f %f %f');
102+
fclose(fid);
103+
104+
ids = C{1};
105+
b1 = C{2};
106+
b2 = C{3};
107+
b3 = C{4};
108+
b4 = C{5};
109+
confidence = C{6};
110+
azimuth = C{7};
111+
elevation = C{8};
112+
rotation = C{9};
113+
BB = [b1 b2 b3 b4]';
114+
115+
% sort detections by decreasing confidence
116+
[~, si]=sort(-confidence);
117+
ids = ids(si);
118+
BB = BB(:,si);
119+
azimuth = azimuth(si);
120+
elevation = elevation(si);
121+
rotation = rotation(si);
122+
123+
% assign detections to ground truth objects
124+
nd = length(confidence);
125+
tp = zeros(nd, 1);
126+
fp = zeros(nd, 1);
127+
vp = zeros(nd, 1);
128+
vs = zeros(nd, 1);
129+
ignore = false(nd, 1);
130+
vd = zeros(nd, 3);
131+
tic;
132+
for d = 1:nd
133+
% display progress
134+
if toc > 1
135+
fprintf('%s: pr: compute: %d/%d\n', cls, d, nd);
136+
tic;
137+
end
138+
139+
% find ground truth image
140+
i = find(strcmp(ids{d}, gtids) == 1);
141+
if isempty(i)
142+
error('unrecognized image "%s"', ids{d});
143+
elseif length(i)>1
144+
error('multiple image "%s"', ids{d});
145+
end
146+
147+
% assign detection to ground truth object if any
148+
bb = BB(:,d);
149+
ovmax = -inf;
150+
jmax = -1;
151+
for j = 1:size(gt(i).BB, 2)
152+
bbgt = gt(i).BB(:,j);
153+
bi = [max(bb(1),bbgt(1)) ; max(bb(2),bbgt(2)) ; min(bb(3),bbgt(3)) ; min(bb(4),bbgt(4))];
154+
iw = bi(3) - bi(1) + 1;
155+
ih = bi(4) - bi(2) + 1;
156+
if iw > 0 && ih > 0
157+
% compute overlap as area of intersection / area of union
158+
ua=(bb(3)-bb(1)+1)*(bb(4)-bb(2)+1)+...
159+
(bbgt(3)-bbgt(1)+1)*(bbgt(4)-bbgt(2)+1)-...
160+
iw*ih;
161+
ov= iw * ih / ua;
162+
if ov > ovmax
163+
ovmax = ov;
164+
jmax = j;
165+
end
166+
end
167+
end
168+
% assign detection as true positive/don't care/false positive
169+
if ovmax >= minoverlap
170+
if ~gt(i).det(jmax)
171+
tp(d) = 1; % true positive
172+
gt(i).det(jmax) = true;
173+
% compute viewpoint accuracy
174+
Rgt = gt(i).view{jmax};
175+
if isempty(Rgt) == 0
176+
R = rotation_matrix(azimuth(d), elevation(d), rotation(d));
177+
X = logm(Rgt' * R);
178+
angle = 1/sqrt(2) * norm(X, 'fro');
179+
% viewpoint similarity
180+
vs(d) = (1 + cos(angle)) / 2;
181+
% viewpoint accraucy
182+
if abs(angle) < pi/6
183+
vp(d) = 1;
184+
end
185+
186+
% compute angle errors
187+
da = abs(angdiff(azimuth(d), gt(i).azimuth(jmax)));
188+
de = abs(angdiff(elevation(d), gt(i).elevation(jmax)));
189+
dr = abs(angdiff(rotation(d), gt(i).rotation(jmax)));
190+
vd(d, 1) = da / ( da + de + dr);
191+
vd(d, 2) = de / ( da + de + dr);
192+
vd(d, 3) = dr / ( da + de + dr);
193+
end
194+
else
195+
fp(d) = 1; % false positive (multiple detection)
196+
end
197+
if gt(i).ignore(jmax)
198+
ignore(d) = true;
199+
end
200+
else
201+
fp(d) = 1; % false positive
202+
end
203+
end
204+
205+
% compute precision/recall
206+
fp_det = cumsum(fp);
207+
tp_det = cumsum(tp);
208+
rec_det = tp_det / npos;
209+
prec_det = tp_det ./ (fp_det + tp_det);
210+
ap_det = VOCap(rec_det, prec_det);
211+
212+
aps_det(k) = ap_det;
213+
recalls_det{k} = rec_det;
214+
precisions_det{k} = prec_det;
215+
fprintf('%s, ap: %f\n', cls, ap_det);
216+
217+
% compute precision/recall for view
218+
fp_view = cumsum(fp(~ignore));
219+
tp_view = cumsum(tp(~ignore));
220+
vp_view = cumsum(vp(~ignore));
221+
vs_view = cumsum(vs(~ignore));
222+
rec_view = tp_view / npos_view;
223+
prec_view = tp_view ./ (fp_view + tp_view);
224+
ap_view = VOCap(rec_view, prec_view);
225+
226+
accu_view = vp_view ./ (fp_view + tp_view);
227+
avp_view = VOCap(rec_view, accu_view);
228+
229+
sim_view = vs_view ./ (fp_view + tp_view);
230+
avs_view = VOCap(rec_view, sim_view);
231+
232+
aps_view(k) = ap_view;
233+
avps_view(k) = avp_view;
234+
avss_view(k) = avs_view;
235+
recalls_view{k} = rec_view;
236+
precisions_view{k} = prec_view;
237+
accuracies_view{k} = accu_view;
238+
similarities_view{k} = sim_view;
239+
fprintf('%s, ap view: %f, avp view %f, avs view %f\n', cls, ap_view, avp_view, avs_view);
240+
241+
% keep the view error distribution
242+
vd = vd(tp == 1 & ignore == 0, :);
243+
errors_view{k} = vd;
244+
end
245+
246+
% write to file
247+
fid = fopen(sprintf('views_%s_%d.txt', method, minoverlap*100), 'w');
248+
for i = 1:num_cls
249+
fprintf(fid, '%s %f %f %f %f\n', classes{i}, aps_det(i), aps_view(i), avps_view(i), avss_view(i));
250+
end
251+
fprintf(fid, 'mAP %f %f %f %f\n', mean(aps_det), mean(aps_view), mean(avps_view), mean(avss_view));
252+
fclose(fid);
253+
254+
% save to matfile
255+
matfile = sprintf('views_%s_%d.mat', method, minoverlap*100);
256+
save(matfile, 'recalls_det', 'precisions_det', 'aps_det', ...
257+
'recalls_view', 'precisions_view', 'aps_view', 'avps_view', 'avss_view', 'errors_view', '-v7.3');
258+
259+
delete(poolobj);
260+
261+
function d = angdiff(a, b)
262+
263+
d = a - b;
264+
if d > pi
265+
d = d - 2*pi;
266+
end
267+
if d < -pi
268+
d = d + 2*pi;
269+
end

globals.m

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,5 @@
66
% --------------------------------------------------------
77
function opt = globals()
88

9-
opt.root = '/var/Projects/ObjectNet3D';
9+
opt.root = '/datasets/ObjectNet3D';
1010
opt.shapenetcore = '/var/Projects/ShapeNetCore.v1';

rotation_matrix.m

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
function R = rotation_matrix(a, e, theta)
2+
3+
a = -a;
4+
e = pi/2+e;
5+
theta=-theta;
6+
7+
% rotation matrix
8+
Rz = [cos(a) -sin(a) 0; sin(a) cos(a) 0; 0 0 1]; %rotate by a
9+
Rx = [1 0 0; 0 cos(e) -sin(e); 0 sin(e) cos(e)]; %rotate by e
10+
Rz2= [cos(theta), -sin(theta),0; sin(theta), cos(theta), 0; 0,0,1];
11+
R = Rz2*Rx*Rz;

0 commit comments

Comments
 (0)