
AlgebraicDecisionTree Helpers #1696

Merged — 54 commits, Aug 21, 2024
Commits (54)
7695fd6
Improved HybridBayesNet::optimize with proper model selection
varunagrawal Nov 20, 2023
39f7ac2
handle nullptrs in GaussianMixture::error
varunagrawal Nov 20, 2023
c374a26
nicer HybridBayesNet::optimize with normalized errors
varunagrawal Nov 20, 2023
ed5ef66
Merge branch 'hybrid-printerrors' into model-selection-integration
varunagrawal Nov 27, 2023
50670da
HybridValues formatting
varunagrawal Dec 12, 2023
af490e9
sum and normalize helper methods for the AlgebraicDecisionTree
varunagrawal Dec 12, 2023
c004bd8
test for differing covariances
varunagrawal Dec 13, 2023
7b56c96
differing means test
varunagrawal Dec 15, 2023
e549a9b
normalize model selection term
varunagrawal Dec 15, 2023
b2638c8
max and min functions for AlgebraicDecisionTree
varunagrawal Dec 17, 2023
6f09be5
error normalization and log-sum-exp trick
varunagrawal Dec 17, 2023
3660429
handle numerical instability
varunagrawal Dec 18, 2023
07ddec5
remove stray comment
varunagrawal Dec 21, 2023
c6584f6
minor test cleanup
varunagrawal Dec 22, 2023
ebcf958
better, more correct version of model selection
varunagrawal Dec 25, 2023
1e298be
Better way of handling assignments
varunagrawal Dec 25, 2023
b4f07a0
cleaner model selection computation
varunagrawal Dec 26, 2023
6f4343c
almost working
varunagrawal Dec 26, 2023
409938f
improved model selection code
varunagrawal Dec 26, 2023
93c824c
overload == operator for GaussianBayesNet and VectorValues
varunagrawal Dec 27, 2023
b20d33d
logNormalizationConstant() for GaussianBayesNet
varunagrawal Dec 27, 2023
3a89653
helper methods in GaussianMixture for model selection
varunagrawal Dec 27, 2023
6f66d04
handle pruning in model selection
varunagrawal Dec 27, 2023
0d05810
update wrapper for LM with Ordering parameter
varunagrawal Jan 3, 2024
651f999
print logNormalizationConstant for Gaussian conditionals
varunagrawal Jan 3, 2024
114c86f
GaussianConditional wrapper for arbitrary number of keys
varunagrawal Jan 3, 2024
82e0c0d
take comment all the way
varunagrawal Jan 3, 2024
8a61c49
add model_selection method to HybridBayesNet
varunagrawal Jan 3, 2024
3ba54eb
improved docstrings
varunagrawal Jan 3, 2024
bb95cd4
remove `using std::dynamic_pointer_cast;`
varunagrawal Jan 3, 2024
6d50de8
docstring for HybridBayesNet::assembleTree
varunagrawal Jan 3, 2024
9ad7697
Merge branch 'hybrid-printerrors' into model-selection-integration
varunagrawal Jan 4, 2024
502e8cf
Merge branch 'model-selection-integration' into hybrid-lognormconstant
varunagrawal Jan 4, 2024
a80b5d4
Merge branch 'hybrid-printerrors' into model-selection-integration
varunagrawal Jan 5, 2024
0430fee
improved naming and documentation
varunagrawal Jan 7, 2024
afcb933
document return type
varunagrawal Jan 7, 2024
c5bfd52
better printing of GaussianMixtureFactor
varunagrawal Jan 12, 2024
e9e2ef9
Merge pull request #1705 from borglab/hybrid-lognormconstant
varunagrawal Feb 20, 2024
538871a
Merge branch 'develop' into model-selection-integration
varunagrawal Mar 18, 2024
1501b7c
Merge branch 'develop' into model-selection-integration
varunagrawal Jun 28, 2024
eb9ea78
Merge branch 'develop' into model-selection-integration
varunagrawal Jul 2, 2024
a9cf4a0
fix namespacing
varunagrawal Jul 3, 2024
2a080bb
Merge branch 'develop' into model-selection-integration
varunagrawal Jul 29, 2024
113a7f8
added more comments and compute GaussianMixture before tau
varunagrawal Aug 5, 2024
2430abb
test for different error values in BN from MixtureFactor
varunagrawal Aug 7, 2024
3c722ac
update GaussianMixtureFactor to record normalizers, and add unit tests
varunagrawal Aug 20, 2024
d4e5a9b
different means test both via direct factor definition and toFactorGraph
varunagrawal Aug 20, 2024
fef929f
clean up model selection
varunagrawal Aug 20, 2024
654bad7
remove model selection code
varunagrawal Aug 20, 2024
6b1d89d
fix testMixtureFactor test
varunagrawal Aug 20, 2024
6d9fc8e
undo change in GaussianMixture
varunagrawal Aug 20, 2024
fd2062b
remove changes so we can break up PR into smaller ones
varunagrawal Aug 20, 2024
cea84b8
reduce the diff even more
varunagrawal Aug 20, 2024
73d971a
unit tests for AlgebraicDecisionTree helper methods
varunagrawal Aug 21, 2024
36 changes: 36 additions & 0 deletions gtsam/discrete/AlgebraicDecisionTree.h
@@ -196,6 +196,42 @@ namespace gtsam {
return this->apply(g, &Ring::div);
}

/// Compute sum of all values
[Review comment — Member] Unit tests?

double sum() const {
double sum = 0;
auto visitor = [&](double y) { sum += y; };
this->visit(visitor);
return sum;
}

/**
* @brief Helper method to perform normalization such that all leaves in the
* tree sum to 1
*
* @param sum The normalizer: the sum of all leaf values.
* @return AlgebraicDecisionTree
*/
AlgebraicDecisionTree normalize(double sum) const {
return this->apply([&sum](const double& x) { return x / sum; });
}

/// Find the minimum value amongst all leaves
double min() const {
double min = std::numeric_limits<double>::max();
auto visitor = [&](double x) { min = x < min ? x : min; };
this->visit(visitor);
return min;
}

/// Find the maximum value amongst all leaves
double max() const {
// Get the most negative value
double max = -std::numeric_limits<double>::max();
auto visitor = [&](double x) { max = x > max ? x : max; };
this->visit(visitor);
return max;
}

/** sum out variable */
AlgebraicDecisionTree sum(const L& label, size_t cardinality) const {
return this->combine(label, cardinality, &Ring::add);
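The semantics of the new helpers can be sketched outside the tree machinery. The following is a self-contained illustration (hypothetical names, not GTSAM code) of what `sum`, `normalize`, `min`, and `max` compute over the leaves, with a plain vector standing in for the tree's `visit`/`apply` traversal:

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// Stand-ins for the ADT helpers: the real methods walk the decision tree's
// leaves via this->visit(visitor); here the leaves are just a vector.
double leafSum(const std::vector<double>& leaves) {
  double s = 0.0;
  for (double y : leaves) s += y;  // visitor accumulates each leaf
  return s;
}

std::vector<double> leafNormalize(const std::vector<double>& leaves, double s) {
  std::vector<double> out;
  for (double x : leaves) out.push_back(x / s);  // apply(x -> x / sum)
  return out;
}

double leafMin(const std::vector<double>& leaves) {
  return *std::min_element(leaves.begin(), leaves.end());
}

double leafMax(const std::vector<double>& leaves) {
  return *std::max_element(leaves.begin(), leaves.end());
}
```

With leaves {1, 3}, `leafSum` gives 4 and `leafNormalize` gives {0.25, 0.75}, whose leaves sum to 1 as the docstring of `normalize` promises.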
66 changes: 62 additions & 4 deletions gtsam/hybrid/GaussianMixture.cpp
@@ -24,6 +24,7 @@
#include <gtsam/hybrid/GaussianMixtureFactor.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/inference/Conditional-inst.h>
#include <gtsam/linear/GaussianBayesNet.h>
#include <gtsam/linear/GaussianFactorGraph.h>

namespace gtsam {
@@ -92,6 +93,34 @@ GaussianFactorGraphTree GaussianMixture::asGaussianFactorGraphTree() const {
return {conditionals_, wrap};
}

/* *******************************************************************************/
GaussianBayesNetTree GaussianMixture::add(
const GaussianBayesNetTree &sum) const {
using Y = GaussianBayesNet;
auto add = [](const Y &graph1, const Y &graph2) {
auto result = graph1;
if (graph2.size() == 0) {
return GaussianBayesNet();
}
result.push_back(graph2);
return result;
};
const auto tree = asGaussianBayesNetTree();
return sum.empty() ? tree : sum.apply(tree, add);
}

/* *******************************************************************************/
GaussianBayesNetTree GaussianMixture::asGaussianBayesNetTree() const {
auto wrap = [](const GaussianConditional::shared_ptr &gc) {
if (gc) {
return GaussianBayesNet{gc};
} else {
return GaussianBayesNet();
}
};
return {conditionals_, wrap};
}

/* *******************************************************************************/
size_t GaussianMixture::nrComponents() const {
size_t total = 0;
@@ -316,19 +345,48 @@ AlgebraicDecisionTree<Key> GaussianMixture::logProbability(
AlgebraicDecisionTree<Key> GaussianMixture::errorTree(
const VectorValues &continuousValues) const {
auto errorFunc = [&](const GaussianConditional::shared_ptr &conditional) {
return conditional->error(continuousValues) + //
logConstant_ - conditional->logNormalizationConstant();
// Check if valid pointer
if (conditional) {
return conditional->error(continuousValues) + //
logConstant_ - conditional->logNormalizationConstant();
} else {
// If the pointer is not valid, it means this conditional was pruned,
// so we return maximum error.
return std::numeric_limits<double>::max();
}
};
DecisionTree<Key, double> error_tree(conditionals_, errorFunc);
return error_tree;
}

/* *******************************************************************************/
double GaussianMixture::error(const HybridValues &values) const {
// Check if discrete keys in discrete assignment are
// present in the GaussianMixture
KeyVector dKeys = this->discreteKeys_.indices();
bool valid_assignment = false;
for (auto &&kv : values.discrete()) {
if (std::find(dKeys.begin(), dKeys.end(), kv.first) != dKeys.end()) {
valid_assignment = true;
break;
}
}

// The discrete assignment is not valid so we return 0.0 error.
if (!valid_assignment) {
return 0.0;
}

// Directly index to get the conditional, no need to build the whole tree.
auto conditional = conditionals_(values.discrete());
return conditional->error(values.continuous()) + //
logConstant_ - conditional->logNormalizationConstant();
if (conditional) {
return conditional->error(values.continuous()) + //
logConstant_ - conditional->logNormalizationConstant();
} else {
// If the pointer is not valid, it means this conditional was pruned,
// so we return maximum error.
return std::numeric_limits<double>::max();
}
}

/* *******************************************************************************/
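The pruned-branch handling added in `GaussianMixture::error` reduces to a simple null guard. A minimal sketch of the pattern, with a hypothetical `Conditional` stand-in rather than the GTSAM types:

```cpp
#include <limits>
#include <memory>

// Hypothetical stand-in for a Gaussian conditional; only the error matters here.
struct Conditional {
  double err;
  double error() const { return err; }
};

// Mirrors the new guard: a null conditional means the branch was pruned
// during elimination, so instead of dereferencing a null pointer we report
// the largest representable error, making the pruned branch maximally
// unlikely.
double branchError(const std::shared_ptr<Conditional>& c) {
  if (c) return c->error();
  return std::numeric_limits<double>::max();
}
```

The same sentinel is used in both `errorTree` and `error`, so pruned assignments are consistently assigned vanishing probability.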
18 changes: 17 additions & 1 deletion gtsam/hybrid/GaussianMixture.h
@@ -67,10 +67,17 @@ class GTSAM_EXPORT GaussianMixture
double logConstant_; ///< log of the normalization constant.

/**
* @brief Convert a DecisionTree of factors into a DT of Gaussian FGs.
* @brief Convert a DecisionTree of factors into
* a DecisionTree of Gaussian factor graphs.
*/
GaussianFactorGraphTree asGaussianFactorGraphTree() const;

/**
* @brief Convert a DecisionTree of conditionals into
* a DecisionTree of Gaussian Bayes nets.
*/
GaussianBayesNetTree asGaussianBayesNetTree() const;

/**
* @brief Helper function to get the pruner functor.
*
@@ -248,6 +255,15 @@ class GTSAM_EXPORT GaussianMixture
* @return GaussianFactorGraphTree
*/
GaussianFactorGraphTree add(const GaussianFactorGraphTree &sum) const;

/**
* @brief Merge the Gaussian Bayes Nets in `this` and `sum` while
* maintaining the decision tree structure.
*
* @param sum Decision Tree of Gaussian Bayes Nets
* @return GaussianBayesNetTree
*/
GaussianBayesNetTree add(const GaussianBayesNetTree &sum) const;
/// @}

private:
Expand Down
158 changes: 155 additions & 3 deletions gtsam/hybrid/HybridBayesNet.cpp
@@ -26,6 +26,16 @@ static std::mt19937_64 kRandomNumberGenerator(42);

namespace gtsam {

/* ************************************************************************ */
// Throw a runtime exception for method specified in string s,
// and conditional f:
static void throwRuntimeError(const std::string &s,
const std::shared_ptr<HybridConditional> &f) {
auto &fr = *f;
throw std::runtime_error(s + " not implemented for conditional type " +
demangle(typeid(fr).name()) + ".");
}

/* ************************************************************************* */
void HybridBayesNet::print(const std::string &s,
const KeyFormatter &formatter) const {
@@ -217,18 +227,160 @@ GaussianBayesNet HybridBayesNet::choose(
return gbn;
}

/* ************************************************************************ */
static GaussianBayesNetTree addGaussian(
const GaussianBayesNetTree &gfgTree,
const GaussianConditional::shared_ptr &factor) {
// If the decision tree is not initialized, then initialize it.
if (gfgTree.empty()) {
GaussianBayesNet result{factor};
return GaussianBayesNetTree(result);
} else {
auto add = [&factor](const GaussianBayesNet &graph) {
auto result = graph;
result.push_back(factor);
return result;
};
return gfgTree.apply(add);
}
}

/* ************************************************************************ */
GaussianBayesNetValTree HybridBayesNet::assembleTree() const {
GaussianBayesNetTree result;

for (auto &f : factors_) {
// TODO(dellaert): just use a virtual method defined in HybridFactor.
if (auto gm = std::dynamic_pointer_cast<GaussianMixture>(f)) {
result = gm->add(result);
} else if (auto hc = std::dynamic_pointer_cast<HybridConditional>(f)) {
if (auto gm = hc->asMixture()) {
result = gm->add(result);
} else if (auto g = hc->asGaussian()) {
result = addGaussian(result, g);
} else {
// Has to be discrete.
// TODO(dellaert): in C++20, we can use std::visit.
continue;
}
} else if (std::dynamic_pointer_cast<DiscreteFactor>(f)) {
// Don't do anything for discrete-only factors
// since we want to evaluate continuous values only.
continue;
} else {
// We need to handle the case where the object is actually a
// BayesTreeOrphanWrapper!
throwRuntimeError("HybridBayesNet::assembleTree", f);
}
}

GaussianBayesNetValTree resultTree(result, [](const GaussianBayesNet &gbn) {
return std::make_pair(gbn, 0.0);
});
return resultTree;
}

/* ************************************************************************* */
AlgebraicDecisionTree<Key> HybridBayesNet::model_selection() const {
/*
To perform model selection, we need:
q(mu; M, Z) * sqrt((2*pi)^n * det(Sigma))

With q(mu; M, Z) = exp(-error) and k = 1.0 / sqrt((2*pi)^n * det(Sigma)),
we have q * sqrt((2*pi)^n * det(Sigma)) = q/k = exp(log(q/k))
= exp(log(q) - log(k)) = exp(-error - log(k))
= exp(-(error + log(k))),
where the error is evaluated at the corresponding MAP point, gbn.error(mu).

So we compute (error + log(k)) here and exponentiate later.
*/

GaussianBayesNetValTree bnTree = assembleTree();

GaussianBayesNetValTree bn_error = bnTree.apply(
[this](const Assignment<Key> &assignment,
const std::pair<GaussianBayesNet, double> &gbnAndValue) {
// Compute the X* of each assignment
VectorValues mu = gbnAndValue.first.optimize();

// mu is empty if gbn had nullptrs
if (mu.size() == 0) {
return std::make_pair(gbnAndValue.first,
std::numeric_limits<double>::max());
}

// Compute the error for X* and the assignment
double error =
this->error(HybridValues(mu, DiscreteValues(assignment)));

return std::make_pair(gbnAndValue.first, error);
});

auto trees = unzip(bn_error);
AlgebraicDecisionTree<Key> errorTree = trees.second;

// Compute the log-normalization constant of each Bayes net
AlgebraicDecisionTree<Key> log_norm_constants = DecisionTree<Key, double>(
bnTree, [](const std::pair<GaussianBayesNet, double> &gbnAndValue) {
GaussianBayesNet gbn = gbnAndValue.first;
if (gbn.size() == 0) {
return 0.0;
}
return gbn.logNormalizationConstant();
});

// Compute model selection term (with help from ADT methods)
AlgebraicDecisionTree<Key> model_selection_term =
(errorTree + log_norm_constants) * -1;

double max_log = model_selection_term.max();
AlgebraicDecisionTree<Key> model_selection = DecisionTree<Key, double>(
model_selection_term,
[&max_log](const double &x) { return std::exp(x - max_log); });
model_selection = model_selection.normalize(model_selection.sum());

return model_selection;
}
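The max-shift before exponentiation in `model_selection` is the standard log-sum-exp trick mentioned in the commit history. A self-contained sketch of that normalization step (hypothetical name, not GTSAM code):

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// Normalize a set of log-weights, e.g. -(error + log(k)) per assignment.
// Subtracting the maximum before exponentiating keeps std::exp from
// underflowing to 0 when all log-weights are large and negative.
std::vector<double> normalizedExp(const std::vector<double>& logWeights) {
  double maxLog = *std::max_element(logWeights.begin(), logWeights.end());
  std::vector<double> p;
  double total = 0.0;
  for (double x : logWeights) {
    p.push_back(std::exp(x - maxLog));  // largest weight maps to exp(0) = 1
    total += p.back();
  }
  for (double& v : p) v /= total;  // leaves now sum to 1
  return p;
}
```

Without the shift, log-weights such as {-1000, -1001} would both exponentiate to 0 and the normalizer would divide by zero; with it, the result is {1/(1+e⁻¹), e⁻¹/(1+e⁻¹)} ≈ {0.731, 0.269}.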

/* ************************************************************************* */
HybridValues HybridBayesNet::optimize() const {
// Collect all the discrete factors to compute MPE
DiscreteBayesNet discrete_bn;
DiscreteFactorGraph discrete_fg;
[Review comment — Member] This function is way too large to properly understand. Please break up?


// Compute model selection term
AlgebraicDecisionTree<Key> model_selection_term = model_selection();
[Review comment — Member] naming convention on variables. modelSelectionTerms? This is a global comment on this PR.
[Reply — Collaborator (author)] Done!


// Get the set of all discrete keys involved in model selection
std::set<DiscreteKey> discreteKeySet;
for (auto &&conditional : *this) {
if (conditional->isDiscrete()) {
discrete_bn.push_back(conditional->asDiscrete());
discrete_fg.push_back(conditional->asDiscrete());
} else {
if (conditional->isContinuous()) {
/*
If we are here, it means there are no discrete variables in
the Bayes net (due to strong elimination ordering).
This is a continuous-only problem hence model selection doesn't matter.
*/

} else if (conditional->isHybrid()) {
auto gm = conditional->asMixture();
// Include the discrete keys
std::copy(gm->discreteKeys().begin(), gm->discreteKeys().end(),
std::inserter(discreteKeySet, discreteKeySet.end()));
}
}
}

// Only add model_selection if we have discrete keys
if (discreteKeySet.size() > 0) {
discrete_fg.push_back(DecisionTreeFactor(
DiscreteKeys(discreteKeySet.begin(), discreteKeySet.end()),
model_selection_term));
}

// Solve for the MPE
DiscreteValues mpe = DiscreteFactorGraph(discrete_bn).optimize();
DiscreteValues mpe = discrete_fg.optimize();

// Given the MPE, compute the optimal continuous values.
return HybridValues(optimize(mpe), mpe);
23 changes: 23 additions & 0 deletions gtsam/hybrid/HybridBayesNet.h
@@ -118,6 +118,29 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
return evaluate(values);
}

/**
* @brief Assemble a DecisionTree of (GaussianBayesNet, double) leaves for
* each discrete assignment.
* The included double value is used to make
* constructing the model selection term cleaner and more efficient.
*
* @return GaussianBayesNetValTree
*/
GaussianBayesNetValTree assembleTree() const;

/*
Perform the integration of L(X;M,Z)P(X|M)
which is the model selection term.

By Bayes' rule, P(X|M,Z) ∝ L(X;M,Z)P(X|M),
hence L(X;M,Z)P(X|M) is the unnormalized probability of
the joint Gaussian distribution.

This can be computed by multiplying all the exponentiated errors
of each of the conditionals.
*/
AlgebraicDecisionTree<Key> model_selection() const;
[Review comment — Member] naming convention: modelSelection. Although, it's a bit of a weird name, and the comment does not help much. What is the return value?
[Reply — Collaborator (author)] Done! Hopefully the improved docstring is better.


/**
* @brief Solve the HybridBayesNet by first computing the MPE of all the
* discrete variables and then optimizing the continuous variables based on
8 changes: 8 additions & 0 deletions gtsam/hybrid/HybridFactor.h
@@ -33,6 +33,14 @@ class HybridValues;

/// Alias for DecisionTree of GaussianFactorGraphs
using GaussianFactorGraphTree = DecisionTree<Key, GaussianFactorGraph>;
/// Alias for DecisionTree of GaussianBayesNets
using GaussianBayesNetTree = DecisionTree<Key, GaussianBayesNet>;
/**
* Alias for DecisionTree of (GaussianBayesNet, double) pairs.
* Used for model selection in BayesNet::optimize
*/
using GaussianBayesNetValTree =
DecisionTree<Key, std::pair<GaussianBayesNet, double>>;

KeyVector CollectKeys(const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys);