Skip to content

Commit 017044e

Browse files
authored
Merge pull request #1836 from borglab/improved-api
2 parents 08967d1 + 245f3e0 commit 017044e

File tree

3 files changed

+68
-42
lines changed

3 files changed

+68
-42
lines changed

gtsam/hybrid/HybridGaussianConditional.cpp

+24-24
Original file line numberDiff line numberDiff line change
@@ -28,23 +28,37 @@
2828
#include <gtsam/linear/GaussianFactorGraph.h>
2929

3030
namespace gtsam {
31+
HybridGaussianFactor::FactorValuePairs GetFactorValuePairs(
32+
const HybridGaussianConditional::Conditionals &conditionals) {
33+
auto func = [](const GaussianConditional::shared_ptr &conditional)
34+
-> GaussianFactorValuePair {
35+
double value = 0.0;
36+
// Check if conditional is pruned
37+
if (conditional) {
38+
// Assign log(\sqrt(|2πΣ|)) = -log(1 / sqrt(|2πΣ|))
39+
value = -conditional->logNormalizationConstant();
40+
}
41+
return {std::dynamic_pointer_cast<GaussianFactor>(conditional), value};
42+
};
43+
return HybridGaussianFactor::FactorValuePairs(conditionals, func);
44+
}
3145

3246
HybridGaussianConditional::HybridGaussianConditional(
3347
const KeyVector &continuousFrontals, const KeyVector &continuousParents,
3448
const DiscreteKeys &discreteParents,
3549
const HybridGaussianConditional::Conditionals &conditionals)
3650
: BaseFactor(CollectKeys(continuousFrontals, continuousParents),
37-
discreteParents),
51+
discreteParents, GetFactorValuePairs(conditionals)),
3852
BaseConditional(continuousFrontals.size()),
3953
conditionals_(conditionals) {
40-
// Calculate logConstant_ as the maximum of the log constants of the
54+
// Calculate logConstant_ as the minimum of the log normalizers of the
4155
// conditionals, by visiting the decision tree:
42-
logConstant_ = -std::numeric_limits<double>::infinity();
56+
logConstant_ = std::numeric_limits<double>::infinity();
4357
conditionals_.visit(
4458
[this](const GaussianConditional::shared_ptr &conditional) {
4559
if (conditional) {
46-
this->logConstant_ = std::max(
47-
this->logConstant_, conditional->logNormalizationConstant());
60+
this->logConstant_ = std::min(
61+
this->logConstant_, -conditional->logNormalizationConstant());
4862
}
4963
});
5064
}
@@ -64,29 +78,14 @@ HybridGaussianConditional::HybridGaussianConditional(
6478
DiscreteKeys{discreteParent},
6579
Conditionals({discreteParent}, conditionals)) {}
6680

67-
/* *******************************************************************************/
68-
// TODO(dellaert): This is copy/paste: HybridGaussianConditional should be
69-
// derived from HybridGaussianFactor, no?
70-
GaussianFactorGraphTree HybridGaussianConditional::add(
71-
const GaussianFactorGraphTree &sum) const {
72-
using Y = GaussianFactorGraph;
73-
auto add = [](const Y &graph1, const Y &graph2) {
74-
auto result = graph1;
75-
result.push_back(graph2);
76-
return result;
77-
};
78-
const auto tree = asGaussianFactorGraphTree();
79-
return sum.empty() ? tree : sum.apply(tree, add);
80-
}
81-
8281
/* *******************************************************************************/
8382
GaussianFactorGraphTree HybridGaussianConditional::asGaussianFactorGraphTree()
8483
const {
8584
auto wrap = [this](const GaussianConditional::shared_ptr &gc) {
8685
// First check if conditional has not been pruned
8786
if (gc) {
8887
const double Cgm_Kgcm =
89-
this->logConstant_ - gc->logNormalizationConstant();
88+
-this->logConstant_ - gc->logNormalizationConstant();
9089
// If there is a difference in the covariances, we need to account for
9190
// that since the error is dependent on the mode.
9291
if (Cgm_Kgcm > 0.0) {
@@ -157,7 +156,8 @@ void HybridGaussianConditional::print(const std::string &s,
157156
std::cout << "(" << formatter(dk.first) << ", " << dk.second << "), ";
158157
}
159158
std::cout << std::endl
160-
<< " logNormalizationConstant: " << logConstant_ << std::endl
159+
<< " logNormalizationConstant: " << logNormalizationConstant()
160+
<< std::endl
161161
<< std::endl;
162162
conditionals_.print(
163163
"", [&](Key k) { return formatter(k); },
@@ -216,7 +216,7 @@ std::shared_ptr<HybridGaussianFactor> HybridGaussianConditional::likelihood(
216216
-> GaussianFactorValuePair {
217217
const auto likelihood_m = conditional->likelihood(given);
218218
const double Cgm_Kgcm =
219-
logConstant_ - conditional->logNormalizationConstant();
219+
-logConstant_ - conditional->logNormalizationConstant();
220220
if (Cgm_Kgcm == 0.0) {
221221
return {likelihood_m, 0.0};
222222
} else {
@@ -330,7 +330,7 @@ double HybridGaussianConditional::conditionalError(
330330
// Check if valid pointer
331331
if (conditional) {
332332
return conditional->error(continuousValues) + //
333-
logConstant_ - conditional->logNormalizationConstant();
333+
-logConstant_ - conditional->logNormalizationConstant();
334334
} else {
335335
// If not valid, pointer, it means this conditional was pruned,
336336
// so we return maximum error.

gtsam/hybrid/HybridGaussianConditional.h

+12-17
Original file line numberDiff line numberDiff line change
@@ -51,20 +51,22 @@ class HybridValues;
5151
* @ingroup hybrid
5252
*/
5353
class GTSAM_EXPORT HybridGaussianConditional
54-
: public HybridFactor,
55-
public Conditional<HybridFactor, HybridGaussianConditional> {
54+
: public HybridGaussianFactor,
55+
public Conditional<HybridGaussianFactor, HybridGaussianConditional> {
5656
public:
5757
using This = HybridGaussianConditional;
58-
using shared_ptr = std::shared_ptr<HybridGaussianConditional>;
59-
using BaseFactor = HybridFactor;
60-
using BaseConditional = Conditional<HybridFactor, HybridGaussianConditional>;
58+
using shared_ptr = std::shared_ptr<This>;
59+
using BaseFactor = HybridGaussianFactor;
60+
using BaseConditional = Conditional<BaseFactor, HybridGaussianConditional>;
6161

6262
/// typedef for Decision Tree of Gaussian Conditionals
6363
using Conditionals = DecisionTree<Key, GaussianConditional::shared_ptr>;
6464

6565
private:
6666
Conditionals conditionals_; ///< a decision tree of Gaussian conditionals.
67-
double logConstant_; ///< log of the normalization constant.
67+
///< Negative-log of the normalization constant (log(\sqrt(|2πΣ|))).
68+
///< Take advantage of the neg-log space so everything is a minimization
69+
double logConstant_;
6870

6971
/**
7072
* @brief Convert a HybridGaussianConditional of conditionals into
@@ -107,8 +109,9 @@ class GTSAM_EXPORT HybridGaussianConditional
107109
const Conditionals &conditionals);
108110

109111
/**
110-
* @brief Make a Hybrid Gaussian Conditional from a vector of Gaussian conditionals.
111-
* The DecisionTree-based constructor is preferred over this one.
112+
* @brief Make a Hybrid Gaussian Conditional from a vector of Gaussian
113+
* conditionals. The DecisionTree-based constructor is preferred over this
114+
* one.
112115
*
113116
* @param continuousFrontals The continuous frontal variables
114117
* @param continuousParents The continuous parent variables
@@ -149,7 +152,7 @@ class GTSAM_EXPORT HybridGaussianConditional
149152

150153
/// The log normalization constant is max of the the individual
151154
/// log-normalization constants.
152-
double logNormalizationConstant() const override { return logConstant_; }
155+
double logNormalizationConstant() const override { return -logConstant_; }
153156

154157
/**
155158
* Create a likelihood factor for a hybrid Gaussian conditional,
@@ -232,14 +235,6 @@ class GTSAM_EXPORT HybridGaussianConditional
232235
*/
233236
void prune(const DecisionTreeFactor &discreteProbs);
234237

235-
/**
236-
* @brief Merge the Gaussian Factor Graphs in `this` and `sum` while
237-
* maintaining the decision tree structure.
238-
*
239-
* @param sum Decision Tree of Gaussian Factor Graphs
240-
* @return GaussianFactorGraphTree
241-
*/
242-
GaussianFactorGraphTree add(const GaussianFactorGraphTree &sum) const;
243238
/// @}
244239

245240
private:

gtsam/hybrid/tests/testHybridGaussianConditional.cpp

+32-1
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ TEST(HybridGaussianConditional, Error) {
100100
auto actual = hybrid_conditional.errorTree(vv);
101101

102102
// Check result.
103-
std::vector<DiscreteKey> discrete_keys = {mode};
103+
DiscreteKeys discrete_keys{mode};
104104
std::vector<double> leaves = {conditionals[0]->error(vv),
105105
conditionals[1]->error(vv)};
106106
AlgebraicDecisionTree<Key> expected(discrete_keys, leaves);
@@ -172,6 +172,37 @@ TEST(HybridGaussianConditional, ContinuousParents) {
172172
EXPECT(continuousParentKeys[0] == X(0));
173173
}
174174

175+
/* ************************************************************************* */
176+
/// Check error with mode dependent constants.
177+
TEST(HybridGaussianConditional, Error2) {
178+
using namespace mode_dependent_constants;
179+
auto actual = hybrid_conditional.errorTree(vv);
180+
181+
// Check result.
182+
DiscreteKeys discrete_keys{mode};
183+
double logNormalizer0 = -conditionals[0]->logNormalizationConstant();
184+
double logNormalizer1 = -conditionals[1]->logNormalizationConstant();
185+
double minLogNormalizer = std::min(logNormalizer0, logNormalizer1);
186+
187+
// Expected error is e(X) + log(|2πΣ|).
188+
// We normalize log(|2πΣ|) with min(logNormalizers) so it is non-negative.
189+
std::vector<double> leaves = {
190+
conditionals[0]->error(vv) + logNormalizer0 - minLogNormalizer,
191+
conditionals[1]->error(vv) + logNormalizer1 - minLogNormalizer};
192+
AlgebraicDecisionTree<Key> expected(discrete_keys, leaves);
193+
194+
EXPECT(assert_equal(expected, actual, 1e-6));
195+
196+
// Check for non-tree version.
197+
for (size_t mode : {0, 1}) {
198+
const HybridValues hv{vv, {{M(0), mode}}};
199+
EXPECT_DOUBLES_EQUAL(conditionals[mode]->error(vv) -
200+
conditionals[mode]->logNormalizationConstant() -
201+
minLogNormalizer,
202+
hybrid_conditional.error(hv), 1e-8);
203+
}
204+
}
205+
175206
/* ************************************************************************* */
176207
/// Check that the likelihood is proportional to the conditional density given
177208
/// the measurements.

0 commit comments

Comments
 (0)