@@ -265,16 +265,10 @@ GaussianBayesNetValTree HybridBayesNet::assembleTree() const {
265
265
/* ************************************************************************* */
266
266
AlgebraicDecisionTree<Key> HybridBayesNet::modelSelection () const {
267
267
/*
268
- To perform model selection, we need:
269
- q(mu; M, Z) * sqrt((2*pi)^n*det(Sigma))
270
-
271
- If q(mu; M, Z) = exp(-error) & k = 1.0 / sqrt((2*pi)^n*det(Sigma))
272
- thus, q * sqrt((2*pi)^n*det(Sigma)) = q/k = exp(log(q/k))
273
- = exp(log(q) - log(k)) = exp(-error - log(k))
274
- = exp(-(error + log(k))),
268
+ To perform model selection, we need: q(mu; M, Z) = exp(-error)
275
269
where error is computed at the corresponding MAP point, gbn.error(mu).
276
270
277
- So we compute (error + log(k) ) and exponentiate later
271
+ So we compute (- error) and exponentiate later
278
272
*/
279
273
280
274
GaussianBayesNetValTree bnTree = assembleTree ();
@@ -301,13 +295,16 @@ AlgebraicDecisionTree<Key> HybridBayesNet::modelSelection() const {
301
295
auto trees = unzip (bn_error);
302
296
AlgebraicDecisionTree<Key> errorTree = trees.second ;
303
297
304
- // Only compute logNormalizationConstant
305
- AlgebraicDecisionTree<Key> log_norm_constants =
306
- computeLogNormConstants (bnTree);
307
-
308
298
// Compute model selection term (with help from ADT methods)
309
- AlgebraicDecisionTree<Key> modelSelectionTerm =
310
- computeModelSelectionTerm (errorTree, log_norm_constants);
299
+ AlgebraicDecisionTree<Key> modelSelectionTerm = errorTree * -1 ;
300
+
301
+ // Exponentiate using our scheme
302
+ double max_log = modelSelectionTerm.max ();
303
+ modelSelectionTerm = DecisionTree<Key, double >(
304
+ modelSelectionTerm,
305
+ [&max_log](const double &x) { return std::exp (x - max_log); });
306
+ modelSelectionTerm = modelSelectionTerm.normalize (modelSelectionTerm.sum ());
307
+
311
308
return modelSelectionTerm;
312
309
}
313
310
@@ -531,20 +528,4 @@ AlgebraicDecisionTree<Key> computeLogNormConstants(
531
528
return log_norm_constants;
532
529
}
533
530
534
- /* ************************************************************************* */
535
- AlgebraicDecisionTree<Key> computeModelSelectionTerm (
536
- const AlgebraicDecisionTree<Key> &errorTree,
537
- const AlgebraicDecisionTree<Key> &log_norm_constants) {
538
- AlgebraicDecisionTree<Key> modelSelectionTerm =
539
- (errorTree + log_norm_constants) * -1 ;
540
-
541
- double max_log = modelSelectionTerm.max ();
542
- modelSelectionTerm = DecisionTree<Key, double >(
543
- modelSelectionTerm,
544
- [&max_log](const double &x) { return std::exp (x - max_log); });
545
- modelSelectionTerm = modelSelectionTerm.normalize (modelSelectionTerm.sum ());
546
-
547
- return modelSelectionTerm;
548
- }
549
-
550
531
} // namespace gtsam
0 commit comments