28
28
#include < gtsam/linear/GaussianFactorGraph.h>
29
29
30
30
namespace gtsam {
31
+ HybridGaussianFactor::FactorValuePairs GetFactorValuePairs (
32
+ const HybridGaussianConditional::Conditionals &conditionals) {
33
+ auto func = [](const GaussianConditional::shared_ptr &conditional)
34
+ -> GaussianFactorValuePair {
35
+ double value = 0.0 ;
36
+ // Check if conditional is pruned
37
+ if (conditional) {
38
+ // Assign log(\sqrt(|2πΣ|)) = -log(1 / sqrt(|2πΣ|))
39
+ value = -conditional->logNormalizationConstant ();
40
+ }
41
+ return {std::dynamic_pointer_cast<GaussianFactor>(conditional), value};
42
+ };
43
+ return HybridGaussianFactor::FactorValuePairs (conditionals, func);
44
+ }
31
45
32
46
/* *******************************************************************************/
/// Construct from frontal/parent continuous keys, discrete parent keys, and a
/// decision tree of Gaussian conditionals. The base factor also receives the
/// per-leaf negated log normalization constants via GetFactorValuePairs.
HybridGaussianConditional::HybridGaussianConditional(
    const KeyVector &continuousFrontals, const KeyVector &continuousParents,
    const DiscreteKeys &discreteParents,
    const HybridGaussianConditional::Conditionals &conditionals)
    : BaseFactor(CollectKeys(continuousFrontals, continuousParents),
                 discreteParents, GetFactorValuePairs(conditionals)),
      BaseConditional(continuousFrontals.size()),
      conditionals_(conditionals) {
  // Calculate logConstant_ as the minimum of the (negated) log normalizers
  // of the conditionals, by visiting the decision tree. Start from +inf so
  // any valid leaf lowers it; pruned (null) leaves are ignored.
  logConstant_ = std::numeric_limits<double>::infinity();
  conditionals_.visit([this](const GaussianConditional::shared_ptr &gc) {
    if (!gc) return;  // pruned leaf
    logConstant_ = std::min(logConstant_, -gc->logNormalizationConstant());
  });
}
@@ -64,29 +78,14 @@ HybridGaussianConditional::HybridGaussianConditional(
64
78
DiscreteKeys{discreteParent},
65
79
Conditionals ({discreteParent}, conditionals)) {}
66
80
67
- /* *******************************************************************************/
68
- // TODO(dellaert): This is copy/paste: HybridGaussianConditional should be
69
- // derived from HybridGaussianFactor, no?
70
- GaussianFactorGraphTree HybridGaussianConditional::add (
71
- const GaussianFactorGraphTree &sum) const {
72
- using Y = GaussianFactorGraph;
73
- auto add = [](const Y &graph1, const Y &graph2) {
74
- auto result = graph1;
75
- result.push_back (graph2);
76
- return result;
77
- };
78
- const auto tree = asGaussianFactorGraphTree ();
79
- return sum.empty () ? tree : sum.apply (tree, add);
80
- }
81
-
82
81
/* *******************************************************************************/
83
82
GaussianFactorGraphTree HybridGaussianConditional::asGaussianFactorGraphTree ()
84
83
const {
85
84
auto wrap = [this ](const GaussianConditional::shared_ptr &gc) {
86
85
// First check if conditional has not been pruned
87
86
if (gc) {
88
87
const double Cgm_Kgcm =
89
- this ->logConstant_ - gc->logNormalizationConstant ();
88
+ - this ->logConstant_ - gc->logNormalizationConstant ();
90
89
// If there is a difference in the covariances, we need to account for
91
90
// that since the error is dependent on the mode.
92
91
if (Cgm_Kgcm > 0.0 ) {
@@ -157,7 +156,8 @@ void HybridGaussianConditional::print(const std::string &s,
157
156
std::cout << " (" << formatter (dk.first ) << " , " << dk.second << " ), " ;
158
157
}
159
158
std::cout << std::endl
160
- << " logNormalizationConstant: " << logConstant_ << std::endl
159
+ << " logNormalizationConstant: " << logNormalizationConstant ()
160
+ << std::endl
161
161
<< std::endl;
162
162
conditionals_.print (
163
163
" " , [&](Key k) { return formatter (k); },
@@ -216,7 +216,7 @@ std::shared_ptr<HybridGaussianFactor> HybridGaussianConditional::likelihood(
216
216
-> GaussianFactorValuePair {
217
217
const auto likelihood_m = conditional->likelihood (given);
218
218
const double Cgm_Kgcm =
219
- logConstant_ - conditional->logNormalizationConstant ();
219
+ - logConstant_ - conditional->logNormalizationConstant ();
220
220
if (Cgm_Kgcm == 0.0 ) {
221
221
return {likelihood_m, 0.0 };
222
222
} else {
@@ -330,7 +330,7 @@ double HybridGaussianConditional::conditionalError(
330
330
// Check if valid pointer
331
331
if (conditional) {
332
332
return conditional->error (continuousValues) + //
333
- logConstant_ - conditional->logNormalizationConstant ();
333
+ - logConstant_ - conditional->logNormalizationConstant ();
334
334
} else {
335
335
// If not valid, pointer, it means this conditional was pruned,
336
336
// so we return maximum error.
0 commit comments