
Commit 6e8e257

update model selection code and docs to match the math
1 parent f62805f

File tree: 3 files changed (+30 -66 lines)


gtsam/hybrid/HybridBayesNet.cpp (+11 -30)
@@ -265,16 +265,10 @@ GaussianBayesNetValTree HybridBayesNet::assembleTree() const {
 /* ************************************************************************* */
 AlgebraicDecisionTree<Key> HybridBayesNet::modelSelection() const {
   /*
-    To perform model selection, we need:
-    q(mu; M, Z) * sqrt((2*pi)^n*det(Sigma))
-
-    If q(mu; M, Z) = exp(-error) & k = 1.0 / sqrt((2*pi)^n*det(Sigma))
-    thus, q * sqrt((2*pi)^n*det(Sigma)) = q/k = exp(log(q/k))
-    = exp(log(q) - log(k)) = exp(-error - log(k))
-    = exp(-(error + log(k))),
+    To perform model selection, we need: q(mu; M, Z) = exp(-error)
     where error is computed at the corresponding MAP point, gbn.error(mu).

-    So we compute (error + log(k)) and exponentiate later
+    So we compute (-error) and exponentiate later
   */

   GaussianBayesNetValTree bnTree = assembleTree();
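
A brief aside on the updated comment: writing E_m for gbn.error(mu) under discrete assignment m (notation introduced here, not in the source), the "compute (-error) and exponentiate later" scheme amounts to

\[
  q(\mu;\, m, Z) = \exp(-E_m), \qquad L = \max_{m'} \,(-E_{m'}),
\]
\[
  w_m = \exp\!\big((-E_m) - L\big), \qquad \tilde{q}_m = \frac{w_m}{\sum_{m'} w_{m'}},
\]

so each leaf of the returned tree is proportional to exp(-E_m) and the leaves sum to one; subtracting the maximum log value before exponentiating keeps the largest exponent at zero and avoids underflow.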
@@ -301,13 +295,16 @@ AlgebraicDecisionTree<Key> HybridBayesNet::modelSelection() const {
   auto trees = unzip(bn_error);
   AlgebraicDecisionTree<Key> errorTree = trees.second;

-  // Only compute logNormalizationConstant
-  AlgebraicDecisionTree<Key> log_norm_constants =
-      computeLogNormConstants(bnTree);
-
   // Compute model selection term (with help from ADT methods)
-  AlgebraicDecisionTree<Key> modelSelectionTerm =
-      computeModelSelectionTerm(errorTree, log_norm_constants);
+  AlgebraicDecisionTree<Key> modelSelectionTerm = errorTree * -1;
+
+  // Exponentiate using our scheme
+  double max_log = modelSelectionTerm.max();
+  modelSelectionTerm = DecisionTree<Key, double>(
+      modelSelectionTerm,
+      [&max_log](const double &x) { return std::exp(x - max_log); });
+  modelSelectionTerm = modelSelectionTerm.normalize(modelSelectionTerm.sum());
+
   return modelSelectionTerm;
 }

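For reference, a minimal standalone sketch of the same pipeline on a toy AlgebraicDecisionTree, assuming a GTSAM build matching this commit; the binary DiscreteKey, the two-leaf constructor, and the error values are illustrative assumptions, while the negate / max / exponentiate / normalize calls mirror the code in this hunk.

// Illustrative sketch only; assumes GTSAM headers as of this commit.
#include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/discrete/DiscreteValues.h>

#include <cmath>
#include <iostream>

using namespace gtsam;

int main() {
  // One binary discrete key with hypothetical per-assignment errors.
  DiscreteKey m(0, 2);
  AlgebraicDecisionTree<Key> errorTree(m, 1050.2, 1049.8);

  // Same scheme as modelSelection(): negate, shift by the max, exponentiate,
  // then normalize so the leaves sum to one.
  AlgebraicDecisionTree<Key> modelSelectionTerm = errorTree * -1;
  double max_log = modelSelectionTerm.max();
  modelSelectionTerm = DecisionTree<Key, double>(
      modelSelectionTerm,
      [&max_log](const double &x) { return std::exp(x - max_log); });
  modelSelectionTerm = modelSelectionTerm.normalize(modelSelectionTerm.sum());

  // Look up the normalized weight for each assignment of m.
  for (size_t value : {0, 1}) {
    DiscreteValues assignment;
    assignment[m.first] = value;
    std::cout << "m=" << value << ": " << modelSelectionTerm(assignment)
              << std::endl;
  }
  return 0;
}
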
@@ -531,20 +528,4 @@ AlgebraicDecisionTree<Key> computeLogNormConstants(
   return log_norm_constants;
 }

-/* ************************************************************************* */
-AlgebraicDecisionTree<Key> computeModelSelectionTerm(
-    const AlgebraicDecisionTree<Key> &errorTree,
-    const AlgebraicDecisionTree<Key> &log_norm_constants) {
-  AlgebraicDecisionTree<Key> modelSelectionTerm =
-      (errorTree + log_norm_constants) * -1;
-
-  double max_log = modelSelectionTerm.max();
-  modelSelectionTerm = DecisionTree<Key, double>(
-      modelSelectionTerm,
-      [&max_log](const double &x) { return std::exp(x - max_log); });
-  modelSelectionTerm = modelSelectionTerm.normalize(modelSelectionTerm.sum());
-
-  return modelSelectionTerm;
-}
-
 } // namespace gtsam

gtsam/hybrid/HybridBayesNet.h (+11 -30)
@@ -128,22 +128,17 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
   */
   GaussianBayesNetValTree assembleTree() const;

-  /*
-  Compute L(M;Z), the likelihood of the discrete model M
-  given the measurements Z.
-  This is called the model selection term.
-
-  To do so, we perform the integration of L(M;Z) ∝ L(X;M,Z)P(X|M).
-
-  By Bayes' rule, P(X|M,Z) ∝ L(X;M,Z)P(X|M),
-  hence L(X;M,Z)P(X|M) is the unnormalized probabilty of
-  the joint Gaussian distribution.
-
-  This can be computed by multiplying all the exponentiated errors
-  of each of the conditionals.
-
-  Return a tree where each leaf value is L(M_i;Z).
-  */
+  /**
+   * @brief Compute the model selection term q(μ_X; M, Z)
+   * given the error for each discrete assignment.
+   *
+   * The q(μ) terms are obtained as a result of elimination
+   * as part of the separator factor.
+   *
+   * Perform normalization to handle underflow issues.
+   *
+   * @return AlgebraicDecisionTree<Key>
+   */
   AlgebraicDecisionTree<Key> modelSelection() const;

   /**
@@ -301,18 +296,4 @@ GaussianBayesNetTree addGaussian(const GaussianBayesNetTree &gbnTree,
 AlgebraicDecisionTree<Key> computeLogNormConstants(
     const GaussianBayesNetValTree &bnTree);

-/**
- * @brief Compute the model selection term L(M; Z, X) given the error
- * and log normalization constants.
- *
- * Perform normalization to handle underflow issues.
- *
- * @param errorTree
- * @param log_norm_constants
- * @return AlgebraicDecisionTree<Key>
- */
-AlgebraicDecisionTree<Key> computeModelSelectionTerm(
-    const AlgebraicDecisionTree<Key> &errorTree,
-    const AlgebraicDecisionTree<Key> &log_norm_constants);
-
 } // namespace gtsam
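
The header documentation above notes that normalization is performed to handle underflow issues. A short standalone illustration (plain C++, hypothetical error values, no GTSAM dependency) of why raw exp(-error) fails for errors of this magnitude and why the shift-by-max used in modelSelection() preserves the ratios:

#include <algorithm>
#include <cmath>
#include <iostream>
#include <vector>

int main() {
  // Hypothetical per-assignment errors, large as gbn.error(mu) typically is.
  const std::vector<double> errors = {1050.2, 1052.7, 1049.8};

  // Naive: exp(-error) underflows to 0.0 for every assignment.
  for (double e : errors) std::cout << std::exp(-e) << " ";  // 0 0 0
  std::cout << "\n";

  // Shifted: subtract the maximum of (-error), i.e. the minimum error,
  // so the best assignment maps to exp(0) = 1, then normalize.
  const double max_log = -*std::min_element(errors.begin(), errors.end());
  std::vector<double> w;
  double sum = 0.0;
  for (double e : errors) {
    w.push_back(std::exp(-e - max_log));
    sum += w.back();
  }
  for (double &wi : w) wi /= sum;                // weights now sum to one
  for (double wi : w) std::cout << wi << " ";    // approx. 0.39 0.03 0.58
  std::cout << "\n";
  return 0;
}

With errors near 1000, std::exp(-error) is far below the smallest representable double and flushes to zero, so every leaf would look identical; shifting by the maximum log value before exponentiating keeps the relative weights intact.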

gtsam/hybrid/HybridBayesTree.cpp (+8 -6)
@@ -111,13 +111,15 @@ AlgebraicDecisionTree<Key> HybridBayesTree::modelSelection() const {
   auto trees = unzip(bn_error);
   AlgebraicDecisionTree<Key> errorTree = trees.second;

-  // Only compute logNormalizationConstant
-  AlgebraicDecisionTree<Key> log_norm_constants =
-      computeLogNormConstants(bnTree);
-
   // Compute model selection term (with help from ADT methods)
-  AlgebraicDecisionTree<Key> modelSelectionTerm =
-      computeModelSelectionTerm(errorTree, log_norm_constants);
+  AlgebraicDecisionTree<Key> modelSelectionTerm = errorTree * -1;
+
+  // Exponentiate using our scheme
+  double max_log = modelSelectionTerm.max();
+  modelSelectionTerm = DecisionTree<Key, double>(
+      modelSelectionTerm,
+      [&max_log](const double& x) { return std::exp(x - max_log); });
+  modelSelectionTerm = modelSelectionTerm.normalize(modelSelectionTerm.sum());

   return modelSelectionTerm;
 }
