-
Notifications
You must be signed in to change notification settings - Fork 804
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
AlgebraicDecisionTree Helpers #1696
Changes from 31 commits
7695fd6
39f7ac2
c374a26
ed5ef66
50670da
af490e9
c004bd8
7b56c96
e549a9b
b2638c8
6f09be5
3660429
07ddec5
c6584f6
ebcf958
1e298be
b4f07a0
6f4343c
409938f
93c824c
b20d33d
3a89653
6f66d04
0d05810
651f999
114c86f
82e0c0d
8a61c49
3ba54eb
bb95cd4
6d50de8
9ad7697
502e8cf
a80b5d4
0430fee
afcb933
c5bfd52
e9e2ef9
538871a
1501b7c
eb9ea78
a9cf4a0
2a080bb
113a7f8
2430abb
3c722ac
d4e5a9b
fef929f
654bad7
6b1d89d
6d9fc8e
fd2062b
cea84b8
73d971a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,6 +26,16 @@ static std::mt19937_64 kRandomNumberGenerator(42); | |
|
||
namespace gtsam { | ||
|
||
/* ************************************************************************ */ | ||
// Throw a runtime exception for method specified in string s, | ||
// and conditional f: | ||
static void throwRuntimeError(const std::string &s, | ||
const std::shared_ptr<HybridConditional> &f) { | ||
auto &fr = *f; | ||
throw std::runtime_error(s + " not implemented for conditional type " + | ||
demangle(typeid(fr).name()) + "."); | ||
} | ||
|
||
/* ************************************************************************* */ | ||
void HybridBayesNet::print(const std::string &s, | ||
const KeyFormatter &formatter) const { | ||
|
@@ -217,18 +227,160 @@ GaussianBayesNet HybridBayesNet::choose( | |
return gbn; | ||
} | ||
|
||
/* ************************************************************************ */ | ||
static GaussianBayesNetTree addGaussian( | ||
const GaussianBayesNetTree &gfgTree, | ||
const GaussianConditional::shared_ptr &factor) { | ||
// If the decision tree is not initialized, then initialize it. | ||
if (gfgTree.empty()) { | ||
GaussianBayesNet result{factor}; | ||
return GaussianBayesNetTree(result); | ||
} else { | ||
auto add = [&factor](const GaussianBayesNet &graph) { | ||
auto result = graph; | ||
result.push_back(factor); | ||
return result; | ||
}; | ||
return gfgTree.apply(add); | ||
} | ||
} | ||
|
||
/* ************************************************************************ */ | ||
GaussianBayesNetValTree HybridBayesNet::assembleTree() const { | ||
GaussianBayesNetTree result; | ||
|
||
for (auto &f : factors_) { | ||
// TODO(dellaert): just use a virtual method defined in HybridFactor. | ||
if (auto gm = std::dynamic_pointer_cast<GaussianMixture>(f)) { | ||
result = gm->add(result); | ||
} else if (auto hc = std::dynamic_pointer_cast<HybridConditional>(f)) { | ||
if (auto gm = hc->asMixture()) { | ||
result = gm->add(result); | ||
} else if (auto g = hc->asGaussian()) { | ||
result = addGaussian(result, g); | ||
} else { | ||
// Has to be discrete. | ||
// TODO(dellaert): in C++20, we can use std::visit. | ||
continue; | ||
} | ||
} else if (std::dynamic_pointer_cast<DiscreteFactor>(f)) { | ||
// Don't do anything for discrete-only factors | ||
// since we want to evaluate continuous values only. | ||
continue; | ||
} else { | ||
// We need to handle the case where the object is actually an | ||
// BayesTreeOrphanWrapper! | ||
throwRuntimeError("HybridBayesNet::assembleTree", f); | ||
} | ||
} | ||
|
||
GaussianBayesNetValTree resultTree(result, [](const GaussianBayesNet &gbn) { | ||
return std::make_pair(gbn, 0.0); | ||
}); | ||
return resultTree; | ||
} | ||
|
||
/* ************************************************************************* */ | ||
AlgebraicDecisionTree<Key> HybridBayesNet::model_selection() const { | ||
/* | ||
To perform model selection, we need: | ||
q(mu; M, Z) * sqrt((2*pi)^n*det(Sigma)) | ||
|
||
If q(mu; M, Z) = exp(-error) & k = 1.0 / sqrt((2*pi)^n*det(Sigma)) | ||
thus, q * sqrt((2*pi)^n*det(Sigma)) = q/k = exp(log(q/k)) | ||
= exp(log(q) - log(k)) = exp(-error - log(k)) | ||
= exp(-(error + log(k))), | ||
where error is computed at the corresponding MAP point, gbn.error(mu). | ||
|
||
So we compute (error + log(k)) and exponentiate later | ||
*/ | ||
|
||
GaussianBayesNetValTree bnTree = assembleTree(); | ||
|
||
GaussianBayesNetValTree bn_error = bnTree.apply( | ||
[this](const Assignment<Key> &assignment, | ||
const std::pair<GaussianBayesNet, double> &gbnAndValue) { | ||
// Compute the X* of each assignment | ||
VectorValues mu = gbnAndValue.first.optimize(); | ||
|
||
// mu is empty if gbn had nullptrs | ||
if (mu.size() == 0) { | ||
return std::make_pair(gbnAndValue.first, | ||
std::numeric_limits<double>::max()); | ||
} | ||
|
||
// Compute the error for X* and the assignment | ||
double error = | ||
this->error(HybridValues(mu, DiscreteValues(assignment))); | ||
|
||
return std::make_pair(gbnAndValue.first, error); | ||
}); | ||
|
||
auto trees = unzip(bn_error); | ||
AlgebraicDecisionTree<Key> errorTree = trees.second; | ||
|
||
// Only compute logNormalizationConstant | ||
AlgebraicDecisionTree<Key> log_norm_constants = DecisionTree<Key, double>( | ||
bnTree, [](const std::pair<GaussianBayesNet, double> &gbnAndValue) { | ||
GaussianBayesNet gbn = gbnAndValue.first; | ||
if (gbn.size() == 0) { | ||
return 0.0; | ||
} | ||
return gbn.logNormalizationConstant(); | ||
}); | ||
|
||
// Compute model selection term (with help from ADT methods) | ||
AlgebraicDecisionTree<Key> model_selection_term = | ||
(errorTree + log_norm_constants) * -1; | ||
|
||
double max_log = model_selection_term.max(); | ||
AlgebraicDecisionTree<Key> model_selection = DecisionTree<Key, double>( | ||
model_selection_term, | ||
[&max_log](const double &x) { return std::exp(x - max_log); }); | ||
model_selection = model_selection.normalize(model_selection.sum()); | ||
|
||
return model_selection; | ||
} | ||
|
||
/* ************************************************************************* */ | ||
HybridValues HybridBayesNet::optimize() const { | ||
// Collect all the discrete factors to compute MPE | ||
DiscreteBayesNet discrete_bn; | ||
DiscreteFactorGraph discrete_fg; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This function is way too large to properly understand. Please break up? |
||
|
||
// Compute model selection term | ||
AlgebraicDecisionTree<Key> model_selection_term = model_selection(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. naming convention on variables. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done! |
||
|
||
// Get the set of all discrete keys involved in model selection | ||
std::set<DiscreteKey> discreteKeySet; | ||
for (auto &&conditional : *this) { | ||
if (conditional->isDiscrete()) { | ||
discrete_bn.push_back(conditional->asDiscrete()); | ||
discrete_fg.push_back(conditional->asDiscrete()); | ||
} else { | ||
if (conditional->isContinuous()) { | ||
/* | ||
If we are here, it means there are no discrete variables in | ||
the Bayes net (due to strong elimination ordering). | ||
This is a continuous-only problem hence model selection doesn't matter. | ||
*/ | ||
|
||
} else if (conditional->isHybrid()) { | ||
auto gm = conditional->asMixture(); | ||
// Include the discrete keys | ||
std::copy(gm->discreteKeys().begin(), gm->discreteKeys().end(), | ||
std::inserter(discreteKeySet, discreteKeySet.end())); | ||
} | ||
} | ||
} | ||
|
||
// Only add model_selection if we have discrete keys | ||
if (discreteKeySet.size() > 0) { | ||
discrete_fg.push_back(DecisionTreeFactor( | ||
DiscreteKeys(discreteKeySet.begin(), discreteKeySet.end()), | ||
model_selection_term)); | ||
} | ||
|
||
// Solve for the MPE | ||
DiscreteValues mpe = DiscreteFactorGraph(discrete_bn).optimize(); | ||
DiscreteValues mpe = discrete_fg.optimize(); | ||
|
||
// Given the MPE, compute the optimal continuous values. | ||
return HybridValues(optimize(mpe), mpe); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -118,6 +118,29 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> { | |
return evaluate(values); | ||
} | ||
|
||
/** | ||
* @brief Assemble a DecisionTree of (GaussianBayesNet, double) leaves for | ||
* each discrete assignment. | ||
* The included double value is used to make | ||
* constructing the model selection term cleaner and more efficient. | ||
* | ||
* @return GaussianBayesNetValTree | ||
*/ | ||
GaussianBayesNetValTree assembleTree() const; | ||
|
||
/* | ||
Perform the integration of L(X;M,Z)P(X|M) | ||
which is the model selection term. | ||
|
||
By Bayes' rule, P(X|M,Z) ∝ L(X;M,Z)P(X|M), | ||
hence L(X;M,Z)P(X|M) is the unnormalized probabilty of | ||
the joint Gaussian distribution. | ||
|
||
This can be computed by multiplying all the exponentiated errors | ||
of each of the conditionals. | ||
*/ | ||
AlgebraicDecisionTree<Key> model_selection() const; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. naming convention: There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done! Hopefully the improved docstring is better. |
||
|
||
/** | ||
* @brief Solve the HybridBayesNet by first computing the MPE of all the | ||
* discrete variables and then optimizing the continuous variables based on | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unit tests?