Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Common errorTree method and its use in HybridGaussianFactorGraph #1837

Merged
merged 11 commits into from
Sep 22, 2024
16 changes: 16 additions & 0 deletions gtsam/hybrid/HybridConditional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,22 @@ double HybridConditional::error(const HybridValues &values) const {
"HybridConditional::error: conditional type not handled");
}

/* ************************************************************************ */
AlgebraicDecisionTree<Key> HybridConditional::errorTree(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not in HybridFactor?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

HybridFactor doesn't have the asGaussian, asHybrid and asDiscrete methods.

const VectorValues &values) const {
if (auto gc = asGaussian()) {
return AlgebraicDecisionTree<Key>(gc->error(values));
}
if (auto gm = asHybrid()) {
return gm->errorTree(values);
}
if (auto dc = asDiscrete()) {
return AlgebraicDecisionTree<Key>(0.0);
}
throw std::runtime_error(
"HybridConditional::error: conditional type not handled");
}

/* ************************************************************************ */
double HybridConditional::logProbability(const HybridValues &values) const {
if (auto gc = asGaussian()) {
Expand Down
10 changes: 10 additions & 0 deletions gtsam/hybrid/HybridConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,16 @@ class GTSAM_EXPORT HybridConditional
/// Return the error of the underlying conditional.
double error(const HybridValues& values) const override;

/**
* @brief Compute error of the HybridConditional as a tree.
*
* @param continuousValues The continuous VectorValues.
* @return AlgebraicDecisionTree<Key> A decision tree with the same keys
* as the conditionals involved, and leaf values as the error.
*/
AlgebraicDecisionTree<Key> errorTree(
const VectorValues& values) const override;

/// Return the log-probability (or density) of the underlying conditional.
double logProbability(const HybridValues& values) const override;

Expand Down
4 changes: 4 additions & 0 deletions gtsam/hybrid/HybridFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,10 @@ class GTSAM_EXPORT HybridFactor : public Factor {
/// Return only the continuous keys for this factor.
const KeyVector &continuousKeys() const { return continuousKeys_; }

/// Virtual class to compute tree of linear errors.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

linear?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This linear because it only accepts VectorValues right? This is also why its definition in HybridNonlinearFactor throws a runtime error.

virtual AlgebraicDecisionTree<Key> errorTree(
const VectorValues &values) const = 0;

/// @}

private:
Expand Down
34 changes: 0 additions & 34 deletions gtsam/hybrid/HybridGaussianConditional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -323,40 +323,6 @@ AlgebraicDecisionTree<Key> HybridGaussianConditional::logProbability(
return DecisionTree<Key, double>(conditionals_, probFunc);
}

/* ************************************************************************* */
double HybridGaussianConditional::conditionalError(
const GaussianConditional::shared_ptr &conditional,
const VectorValues &continuousValues) const {
// Check if valid pointer
if (conditional) {
return conditional->error(continuousValues) + //
-logConstant_ - conditional->logNormalizationConstant();
} else {
// If not valid, pointer, it means this conditional was pruned,
// so we return maximum error.
// This way the negative exponential will give
// a probability value close to 0.0.
return std::numeric_limits<double>::max();
}
}

/* *******************************************************************************/
AlgebraicDecisionTree<Key> HybridGaussianConditional::errorTree(
const VectorValues &continuousValues) const {
auto errorFunc = [&](const GaussianConditional::shared_ptr &conditional) {
return conditionalError(conditional, continuousValues);
};
DecisionTree<Key, double> error_tree(conditionals_, errorFunc);
return error_tree;
}

/* *******************************************************************************/
double HybridGaussianConditional::error(const HybridValues &values) const {
// Directly index to get the conditional, no need to build the whole tree.
auto conditional = conditionals_(values.discrete());
return conditionalError(conditional, values.continuous());
}

/* *******************************************************************************/
double HybridGaussianConditional::logProbability(
const HybridValues &values) const {
Expand Down
47 changes: 3 additions & 44 deletions gtsam/hybrid/HybridGaussianConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,9 @@ class GTSAM_EXPORT HybridGaussianConditional
const Conditionals &conditionals);

/**
* @brief Make a Hybrid Gaussian Conditional from a vector of Gaussian
* conditionals. The DecisionTree-based constructor is preferred over this
* one.
* @brief Make a Hybrid Gaussian Conditional from
* a vector of Gaussian conditionals.
* The DecisionTree-based constructor is preferred over this one.
*
* @param continuousFrontals The continuous frontal variables
* @param continuousParents The continuous parent variables
Expand Down Expand Up @@ -174,43 +174,6 @@ class GTSAM_EXPORT HybridGaussianConditional
AlgebraicDecisionTree<Key> logProbability(
const VectorValues &continuousValues) const;

/**
* @brief Compute the error of this hybrid Gaussian conditional.
*
* This requires some care, as different components may have
* different normalization constants. Let's consider p(x|y,m), where m is
* discrete. We need the error to satisfy the invariant:
*
* error(x;y,m) = K - log(probability(x;y,m))
*
* For all x,y,m. But note that K, the (log) normalization constant defined
* in Conditional.h, should not depend on x, y, or m, only on the parameters
* of the density. Hence, we delegate to the underlying Gaussian
* conditionals, indexed by m, which do satisfy:
*
* log(probability_m(x;y)) = K_m - error_m(x;y)
*
* We resolve by having K == max(K_m) and
*
* error(x;y,m) = error_m(x;y) + K - K_m
*
* which also makes error(x;y,m) >= 0 for all x,y,m.
*
* @param values Continuous values and discrete assignment.
* @return double
*/
double error(const HybridValues &values) const override;

/**
* @brief Compute error of the HybridGaussianConditional as a tree.
*
* @param continuousValues The continuous VectorValues.
* @return AlgebraicDecisionTree<Key> A decision tree on the discrete keys
* only, with the leaf values as the error for each assignment.
*/
AlgebraicDecisionTree<Key> errorTree(
const VectorValues &continuousValues) const;

/**
* @brief Compute the logProbability of this hybrid Gaussian conditional.
*
Expand Down Expand Up @@ -241,10 +204,6 @@ class GTSAM_EXPORT HybridGaussianConditional
/// Check whether `given` has values for all frontal keys.
bool allFrontalsGiven(const VectorValues &given) const;

/// Helper method to compute the error of a conditional.
double conditionalError(const GaussianConditional::shared_ptr &conditional,
const VectorValues &continuousValues) const;

#ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
/** Serialization function */
friend class boost::serialization::access;
Expand Down
21 changes: 18 additions & 3 deletions gtsam/hybrid/HybridGaussianFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,21 +151,36 @@ GaussianFactorGraphTree HybridGaussianFactor::asGaussianFactorGraphTree()
return {factors_, wrap};
}

/* *******************************************************************************/
double HybridGaussianFactor::potentiallyPrunedComponentError(
const sharedFactor &gf, const VectorValues &values) const {
// Check if valid pointer
if (gf) {
return gf->error(values);
} else {
// If not valid, pointer, it means this component was pruned,
// so we return maximum error.
// This way the negative exponential will give
// a probability value close to 0.0.
return std::numeric_limits<double>::max();
}
}
/* *******************************************************************************/
AlgebraicDecisionTree<Key> HybridGaussianFactor::errorTree(
const VectorValues &continuousValues) const {
// functor to convert from sharedFactor to double error value.
auto errorFunc = [&continuousValues](const sharedFactor &gf) {
return gf->error(continuousValues);
auto errorFunc = [this, &continuousValues](const sharedFactor &gf) {
return this->potentiallyPrunedComponentError(gf, continuousValues);
};
DecisionTree<Key, double> error_tree(factors_, errorFunc);
return error_tree;
}

/* *******************************************************************************/
double HybridGaussianFactor::error(const HybridValues &values) const {
// Directly index to get the component, no need to build the whole tree.
const sharedFactor gf = factors_(values.discrete());
return gf->error(values.continuous());
return potentiallyPrunedComponentError(gf, values.continuous());
}

} // namespace gtsam
6 changes: 5 additions & 1 deletion gtsam/hybrid/HybridGaussianFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
* as the factors involved, and leaf values as the error.
*/
AlgebraicDecisionTree<Key> errorTree(
const VectorValues &continuousValues) const;
const VectorValues &continuousValues) const override;

/**
* @brief Compute the log-likelihood, including the log-normalizing constant.
Expand All @@ -169,6 +169,10 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
/// @}

private:
/// Helper method to compute the error of a component.
double potentiallyPrunedComponentError(
const sharedFactor &gf, const VectorValues &continuousValues) const;

#ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
/** Serialization function */
friend class boost::serialization::access;
Expand Down
36 changes: 10 additions & 26 deletions gtsam/hybrid/HybridGaussianFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -329,8 +329,8 @@ static std::shared_ptr<Factor> createDiscreteFactor(

// Logspace version of:
// exp(-factor->error(kEmpty)) / conditional->normalizationConstant();
// We take negative of the logNormalizationConstant `log(1/k)`
// to get `log(k)`.
// We take negative of the logNormalizationConstant `log(k)`
// to get `log(1/k) = log(\sqrt{|2πΣ|})`.
return -factor->error(kEmpty) - conditional->logNormalizationConstant();
};

Expand Down Expand Up @@ -539,36 +539,20 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors,
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::errorTree(
const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> error_tree(0.0);

// Iterate over each factor.
for (auto &factor : factors_) {
// TODO(dellaert): just use a virtual method defined in HybridFactor.
AlgebraicDecisionTree<Key> factor_error;

auto f = factor;
if (auto hc = dynamic_pointer_cast<HybridConditional>(factor)) {
f = hc->inner();
}

if (auto hybridGaussianCond =
dynamic_pointer_cast<HybridGaussianFactor>(f)) {
// Compute factor error and add it.
error_tree = error_tree + hybridGaussianCond->errorTree(continuousValues);
} else if (auto gaussian = dynamic_pointer_cast<GaussianFactor>(f)) {
// If continuous only, get the (double) error
// and add it to the error_tree
double error = gaussian->error(continuousValues);
// Add the gaussian factor error to every leaf of the error tree.
error_tree = error_tree.apply(
[error](double leaf_value) { return leaf_value + error; });
} else if (dynamic_pointer_cast<DiscreteFactor>(f)) {
// If factor at `idx` is discrete-only, we skip.
if (auto f = std::dynamic_pointer_cast<HybridFactor>(factor)) {
// Check for HybridFactor, and call errorTree
error_tree = error_tree + f->errorTree(continuousValues);
} else if (auto f = std::dynamic_pointer_cast<DiscreteFactor>(factor)) {
// Skip discrete factors
continue;
} else {
throwRuntimeError("HybridGaussianFactorGraph::error(VV)", f);
// Everything else is a continuous only factor
HybridValues hv(continuousValues, DiscreteValues());
error_tree = error_tree + AlgebraicDecisionTree<Key>(factor->error(hv));
}
}

return error_tree;
}

Expand Down
7 changes: 7 additions & 0 deletions gtsam/hybrid/HybridNonlinearFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,13 @@ class GTSAM_EXPORT HybridNonlinearFactor : public HybridFactor {
/// Decision tree of Gaussian factors indexed by discrete keys.
Factors factors_;

/// HybridFactor method implementation. Should not be used.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmmm. So then why does it exist in the base class. red flag. Should maybe only exist in Gaussian branch.

Copy link
Collaborator Author

@varunagrawal varunagrawal Sep 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is primarily because of HybridConditional.

The class hierarchy is

  • HybridFactor
    • HybridConditional
    • HybridGaussianFactor
      • HybridGaussianConditional

HybridConditional acts as a common container class for GaussianConditional, DiscreteConditional and HybridGaussianConditional. By defining errorTree in HybridFactor, I can just check for HybridFactor. Otherwise, I'll have to check for HybridConditional and HybridGaussianFactor separately which I felt defeated the purpose of this base class method.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, this is a different hierarchy than the non-hybrid case then:

class GTSAM_EXPORT GaussianConditional :
    public JacobianFactor,
    public Conditional<JacobianFactor, GaussianConditional>

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry there was a rendering error in my previous comment.
It is the same hierarchy as GaussianConditional since HybridGaussianConditional inherits from HybridGaussianFactor, however HybridConditional complicates things (or at least makes it messier).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What would break if you only have errorTree in HybridGaussianFactor? It would be available in HGC. All done?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is a check needed? We would never call errorTree on HybridConditional, as it only makes sense for Gaussian variants, right?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A HybridConditional can have a GaussianConditional or a HybridGaussianConditional as its underlying conditional.

Remember that HybridConditional is a container based on type erasure and not a base class for HybridGaussianConditional.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right. But then why would we need an errorTree function for HybridConditional. If it can be either, then a method providing VectorValues does not make sense to me. Where would that be used? Is it used???

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is used in an incremental sense, when we add the conditionals from a HybridBayesNet to the HybridGaussianFactorGraph, the factor graph now has HybridGaussianFactor (and potential HybridGaussianConditionals) and HybridConditionals as well.

The alternative could be to unwrap the conditionals from the HybridConditional before adding it to the graph.

I am open to suggestions about this since this was nontrivial for me. I think checking for HybridGaussianFactor and HybridConditional separately makes sense.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh and another case: if we go from a HybridBayesNet to a HybridGaussianFactorGraph, then all the factors are added as HybridConditionals which are represented as HybridFactors. This causes DiscreteFactors to be considered in the errorTree computations as well. :(

AlgebraicDecisionTree<Key> errorTree(
const VectorValues& continuousValues) const override {
throw std::runtime_error(
"HybridNonlinearFactor::error does not take VectorValues.");
}

public:
HybridNonlinearFactor() = default;

Expand Down
4 changes: 4 additions & 0 deletions gtsam/hybrid/HybridNonlinearFactorGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ class GTSAM_EXPORT HybridNonlinearFactorGraph : public HybridFactorGraph {
*/
std::shared_ptr<HybridGaussianFactorGraph> linearize(
const Values& continuousValues) const;

/// Expose error(const HybridValues&) method.
using Base::error;

/// @}
};

Expand Down
49 changes: 49 additions & 0 deletions gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -706,6 +706,55 @@ TEST(HybridGaussianFactorGraph, ErrorTreeWithConditional) {
EXPECT(assert_equal(expected, errorTree, 1e-9));
}

/* ****************************************************************************/
// Test hybrid gaussian factor graph errorTree during
// incremental operation
TEST(HybridGaussianFactorGraph, IncrementalErrorTree) {
Switching s(4);

HybridGaussianFactorGraph graph;
graph.push_back(s.linearizedFactorGraph.at(0)); // f(X0)
graph.push_back(s.linearizedFactorGraph.at(1)); // f(X0, X1, M0)
graph.push_back(s.linearizedFactorGraph.at(2)); // f(X1, X2, M1)
graph.push_back(s.linearizedFactorGraph.at(4)); // f(X1)
graph.push_back(s.linearizedFactorGraph.at(5)); // f(X2)
graph.push_back(s.linearizedFactorGraph.at(7)); // f(M0)
graph.push_back(s.linearizedFactorGraph.at(8)); // f(M0, M1)

HybridBayesNet::shared_ptr hybridBayesNet = graph.eliminateSequential();
EXPECT_LONGS_EQUAL(5, hybridBayesNet->size());

HybridValues delta = hybridBayesNet->optimize();
auto error_tree = graph.errorTree(delta.continuous());

std::vector<DiscreteKey> discrete_keys = {{M(0), 2}, {M(1), 2}};
std::vector<double> leaves = {0.99985581, 0.4902432, 0.51936941,
0.0097568009};
AlgebraicDecisionTree<Key> expected_error(discrete_keys, leaves);

// regression
EXPECT(assert_equal(expected_error, error_tree, 1e-7));

graph = HybridGaussianFactorGraph();
graph.push_back(*hybridBayesNet);
graph.push_back(s.linearizedFactorGraph.at(3)); // f(X2, X3, M2)
graph.push_back(s.linearizedFactorGraph.at(6)); // f(X3)

hybridBayesNet = graph.eliminateSequential();
EXPECT_LONGS_EQUAL(7, hybridBayesNet->size());

delta = hybridBayesNet->optimize();
auto error_tree2 = graph.errorTree(delta.continuous());

discrete_keys = {{M(0), 2}, {M(1), 2}, {M(2), 2}};
leaves = {0.50985198, 0.0097577296, 0.50009425, 0,
0.52922138, 0.029127133, 0.50985105, 0.0097567964};
AlgebraicDecisionTree<Key> expected_error2(discrete_keys, leaves);

// regression
EXPECT(assert_equal(expected_error, error_tree, 1e-7));
}

/* ****************************************************************************/
// Check that assembleGraphTree assembles Gaussian factor graphs for each
// assignment.
Expand Down
Loading
Loading