-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathWeightedStdSSF.h
98 lines (80 loc) · 2.41 KB
/
WeightedStdSSF.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
// =========================================================================================
// Structured Class-Labels in Random Forests. This is a re-implementation of
// the work we presented at ICCV'11 in Barcelona, Spain.
//
// If you use this code, please cite the following paper:
// P. Kontschieder, S. Rota Bulò, H. Bischof and M. Pelillo.
// Structured Class-Labels in Random Forests for Semantic Image Labelling. In (ICCV), 2011.
//
// Implementation by Peter Kontschieder and Samuel Rota Bulò
// October 2013
//
// =========================================================================================
#ifndef WEIGHTEDSTDSSF_H_
#define WEIGHTEDSTDSSF_H_
#include "SemanticSegmentationForests.h"
namespace vision
{
/**
 * Error bookkeeping carried per node by WeightedStdSSF.
 * Holds a single scalar: written by WeightedStdSSF::initialize() and
 * WeightedStdSSF::updateError(), and read back by WeightedStdSSF::getError().
 */
struct StdWErrorData { double error; };
/**
 * Semantic segmentation tree whose impurity measure weights the per-class
 * label counts by the inherited `importance` array.
 *
 * NOTE(review): `importance`, `nClasses`, `getLSamples()` and `setRNG()` come
 * from AbstractSemanticSegmentationTree; presumably `importance` holds one
 * non-negative weight per class — confirm against the base class.
 */
class WeightedStdSSF: public AbstractSemanticSegmentationTree<StdWErrorData>
{
public:
    /// @param seed  seed forwarded to setRNG().
    WeightedStdSSF(int seed = 0)
    {
        setRNG(seed);
    }

    virtual ~WeightedStdSSF()
    {
    }

protected:
    /**
     * Builds the class histogram of the samples with indices in
     * [node->getStart(), node->getEnd()) and stores in errorData.error the
     * Shannon entropy (natural log) of the importance-weighted, normalized
     * class distribution.
     *
     * On return prediction.p[i] = importance[i] * hist[i] / sum_k(importance[k] * hist[k])
     * for every class with a positive weighted count. Assumes Prediction::init()
     * resets hist, n and p — TODO confirm.
     */
    void initialize(const TNode<SplitData, Prediction> *node, StdWErrorData &errorData,
            Prediction &prediction) const
    {
        prediction.init(nClasses);
        const vector<LabelledSample<Sample, Label> > &samples = this->getLSamples();
        for (int i = node->getStart(); i < node->getEnd(); ++i)
        {
            ++prediction.hist[samples[i].label.value];
            ++prediction.n;
        }

        // Total importance-weighted mass of the node, accumulated in double.
        double norm = 0;
        for (int i = 0; i < nClasses; ++i)
        {
            prediction.p[i] = importance[i] * prediction.hist[i];
            norm += prediction.p[i];
        }

        errorData.error = 0;
        for (int i = 0; i < nClasses; ++i)
        {
            // The p[i] > 0 guard avoids log(0); with non-negative importance it
            // also guarantees norm > 0 before dividing.
            if (prediction.p[i] > 0)
            {
                // Stay in double precision: narrowing norm or p[i] to float
                // here only discarded accuracy (updateError uses double too).
                prediction.p[i] /= norm;
                errorData.error -= prediction.p[i] * log(prediction.p[i]);
            }
        }
    }

    /**
     * Computes the error of a candidate split from the two children's statistics:
     *   newError.error = -sum_j importance[j] * ( newLeft.hist[j]  * log(newLeft.p[j])
     *                                           + newRight.hist[j] * log(newRight.p[j]) )
     * where classes with p[j] <= 0 contribute nothing.
     *
     * NOTE(review): assumes the caller has already filled hist and p of
     * newLeft/newRight, presumably normalized as in initialize(). If so, each
     * child's term equals its weighted mass times its entropy, whereas
     * initialize() stores the unscaled entropy. Confirm the framework never
     * compares the two values directly. `node` is unused here.
     */
    void updateError(StdWErrorData &newError, const StdWErrorData &errorData,
            const TNode<SplitData, Prediction> *node, Prediction &newLeft,
            Prediction &newRight) const
    {
        // Start from a copy of the parent's data, then overwrite the error.
        newError = errorData;
        newError.error = 0;
        for (int j = 0; j < nClasses; ++j)
        {
            if (newLeft.p[j] > 0)
                newError.error -= importance[j] * newLeft.hist[j] * log(newLeft.p[j]);
            if (newRight.p[j] > 0)
                newError.error -= importance[j] * newRight.hist[j] * log(newRight.p[j]);
        }
    }

    /// @return the scalar error stored in @p error.
    double getError(const StdWErrorData &error) const
    {
        return error.error;
    }
};
}
#endif