
Common errorTree method and its use in HybridGaussianFactorGraph #1837

Merged: 11 commits, Sep 22, 2024
16 changes: 16 additions & 0 deletions gtsam/hybrid/HybridConditional.cpp
@@ -129,6 +129,22 @@ double HybridConditional::error(const HybridValues &values) const {
"HybridConditional::error: conditional type not handled");
}

/* ************************************************************************ */
AlgebraicDecisionTree<Key> HybridConditional::errorTree(
Member: Why not in HybridFactor?

Collaborator Author: HybridFactor doesn't have the asGaussian, asHybrid and asDiscrete methods.

const VectorValues &values) const {
if (auto gc = asGaussian()) {
return AlgebraicDecisionTree<Key>(gc->error(values));
}
if (auto gm = asHybrid()) {
return gm->errorTree(values);
}
if (auto dc = asDiscrete()) {
return AlgebraicDecisionTree<Key>(0.0);
}
throw std::runtime_error(
"HybridConditional::error: conditional type not handled");
}

/* ************************************************************************ */
double HybridConditional::logProbability(const HybridValues &values) const {
if (auto gc = asGaussian()) {
10 changes: 10 additions & 0 deletions gtsam/hybrid/HybridConditional.h
@@ -179,6 +179,16 @@ class GTSAM_EXPORT HybridConditional
/// Return the error of the underlying conditional.
double error(const HybridValues& values) const override;

/**
* @brief Compute error of the HybridConditional as a tree.
*
* @param continuousValues The continuous VectorValues.
* @return AlgebraicDecisionTree<Key> A decision tree with the same keys
* as the conditionals involved, and leaf values as the error.
*/
virtual AlgebraicDecisionTree<Key> errorTree(
const VectorValues& values) const override;

/// Return the log-probability (or density) of the underlying conditional.
double logProbability(const HybridValues& values) const override;

4 changes: 4 additions & 0 deletions gtsam/hybrid/HybridFactor.h
@@ -136,6 +136,10 @@ class GTSAM_EXPORT HybridFactor : public Factor {
/// Return only the continuous keys for this factor.
const KeyVector &continuousKeys() const { return continuousKeys_; }

/// Virtual class to compute tree of linear errors.
Member: linear?

Collaborator Author: This is linear because it only accepts VectorValues, right? This is also why its definition in HybridNonlinearFactor throws a runtime error.

virtual AlgebraicDecisionTree<Key> errorTree(
const VectorValues &values) const = 0;

/// @}

private:
10 changes: 5 additions & 5 deletions gtsam/hybrid/HybridGaussianConditional.h
@@ -109,9 +109,9 @@ class GTSAM_EXPORT HybridGaussianConditional
const Conditionals &conditionals);

/**
* @brief Make a Hybrid Gaussian Conditional from a vector of Gaussian
* conditionals. The DecisionTree-based constructor is preferred over this
* one.
* @brief Make a Hybrid Gaussian Conditional from
* a vector of Gaussian conditionals.
* The DecisionTree-based constructor is preferred over this one.
*
* @param continuousFrontals The continuous frontal variables
* @param continuousParents The continuous parent variables
@@ -208,8 +208,8 @@ class GTSAM_EXPORT HybridGaussianConditional
* @return AlgebraicDecisionTree<Key> A decision tree on the discrete keys
* only, with the leaf values as the error for each assignment.
*/
AlgebraicDecisionTree<Key> errorTree(
const VectorValues &continuousValues) const;
virtual AlgebraicDecisionTree<Key> errorTree(
const VectorValues &continuousValues) const override;

/**
* @brief Compute the logProbability of this hybrid Gaussian conditional.
4 changes: 2 additions & 2 deletions gtsam/hybrid/HybridGaussianFactor.h
@@ -148,8 +148,8 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
* @return AlgebraicDecisionTree<Key> A decision tree with the same keys
* as the factors involved, and leaf values as the error.
*/
AlgebraicDecisionTree<Key> errorTree(
const VectorValues &continuousValues) const;
virtual AlgebraicDecisionTree<Key> errorTree(
const VectorValues &continuousValues) const override;

/**
* @brief Compute the log-likelihood, including the log-normalizing constant.
36 changes: 10 additions & 26 deletions gtsam/hybrid/HybridGaussianFactorGraph.cpp
@@ -329,8 +329,8 @@ static std::shared_ptr<Factor> createDiscreteFactor(

// Logspace version of:
// exp(-factor->error(kEmpty)) / conditional->normalizationConstant();
// We take negative of the logNormalizationConstant `log(1/k)`
// to get `log(k)`.
// We take negative of the logNormalizationConstant `log(k)`
// to get `log(1/k) = log(\sqrt{|2πΣ|})`.
return -factor->error(kEmpty) - conditional->logNormalizationConstant();
};
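For reference, the relation stated in the updated comment, with k the conditional's Gaussian normalization constant (a short derivation for context, not part of the diff):

$$k = \frac{1}{\sqrt{|2\pi\Sigma|}} \quad\Rightarrow\quad -\log k = \log\frac{1}{k} = \tfrac{1}{2}\log|2\pi\Sigma| = \log\sqrt{|2\pi\Sigma|}$$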

@@ -539,36 +539,20 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors,
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::errorTree(
const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> error_tree(0.0);

// Iterate over each factor.
for (auto &factor : factors_) {
// TODO(dellaert): just use a virtual method defined in HybridFactor.
AlgebraicDecisionTree<Key> factor_error;

auto f = factor;
if (auto hc = dynamic_pointer_cast<HybridConditional>(factor)) {
f = hc->inner();
}

if (auto hybridGaussianCond =
dynamic_pointer_cast<HybridGaussianFactor>(f)) {
// Compute factor error and add it.
error_tree = error_tree + hybridGaussianCond->errorTree(continuousValues);
} else if (auto gaussian = dynamic_pointer_cast<GaussianFactor>(f)) {
// If continuous only, get the (double) error
// and add it to the error_tree
double error = gaussian->error(continuousValues);
// Add the gaussian factor error to every leaf of the error tree.
error_tree = error_tree.apply(
[error](double leaf_value) { return leaf_value + error; });
} else if (dynamic_pointer_cast<DiscreteFactor>(f)) {
// If factor at `idx` is discrete-only, we skip.
if (auto f = std::dynamic_pointer_cast<HybridFactor>(factor)) {
// Check for HybridFactor, and call errorTree
error_tree = error_tree + f->errorTree(continuousValues);
} else if (auto f = std::dynamic_pointer_cast<DiscreteFactor>(factor)) {
// Skip discrete factors
continue;
} else {
throwRuntimeError("HybridGaussianFactorGraph::error(VV)", f);
// Everything else is a continuous only factor
HybridValues hv(continuousValues, DiscreteValues());
error_tree = error_tree + AlgebraicDecisionTree<Key>(factor->error(hv));
}
}

return error_tree;
}

7 changes: 7 additions & 0 deletions gtsam/hybrid/HybridNonlinearFactor.h
@@ -74,6 +74,13 @@ class GTSAM_EXPORT HybridNonlinearFactor : public HybridFactor {
/// Decision tree of Gaussian factors indexed by discrete keys.
Factors factors_;

/// HybridFactor method implementation. Should not be used.
Member: Hmm. So then why does it exist in the base class? Red flag. It should maybe only exist in the Gaussian branch.

Collaborator Author (@varunagrawal), Sep 21, 2024: This is primarily because of HybridConditional.

The class hierarchy is

  • HybridFactor
    • HybridConditional
    • HybridGaussianFactor
      • HybridGaussianConditional

HybridConditional acts as a common container class for GaussianConditional, DiscreteConditional, and HybridGaussianConditional. By defining errorTree in HybridFactor, I can just check for HybridFactor. Otherwise, I would have to check for HybridConditional and HybridGaussianFactor separately, which I felt defeated the purpose of this base class method.
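A minimal sketch of the dispatch this base-class virtual enables, mirroring the loop this PR adds to HybridGaussianFactorGraph::errorTree; the free function name sumErrorTrees and the plain std::vector container are illustrative, not GTSAM API:

#include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/discrete/DiscreteFactor.h>
#include <gtsam/hybrid/HybridFactor.h>
#include <gtsam/hybrid/HybridValues.h>

#include <memory>
#include <vector>

using namespace gtsam;

// A single dynamic cast to HybridFactor covers HybridConditional,
// HybridGaussianFactor, and HybridGaussianConditional alike.
AlgebraicDecisionTree<Key> sumErrorTrees(
    const std::vector<std::shared_ptr<Factor>>& factors,
    const VectorValues& continuousValues) {
  AlgebraicDecisionTree<Key> tree(0.0);
  for (const auto& factor : factors) {
    if (auto hf = std::dynamic_pointer_cast<HybridFactor>(factor)) {
      // Hybrid factors and wrapped conditionals provide errorTree directly.
      tree = tree + hf->errorTree(continuousValues);
    } else if (std::dynamic_pointer_cast<DiscreteFactor>(factor)) {
      // Discrete-only factors contribute no continuous error; skip them.
      continue;
    } else {
      // Continuous-only factor: its error is constant over all assignments.
      const HybridValues hv(continuousValues, DiscreteValues());
      tree = tree + AlgebraicDecisionTree<Key>(factor->error(hv));
    }
  }
  return tree;
}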

Member: So, this is a different hierarchy than the non-hybrid case then:

class GTSAM_EXPORT GaussianConditional :
    public JacobianFactor,
    public Conditional<JacobianFactor, GaussianConditional>

Collaborator Author: Sorry, there was a rendering error in my previous comment. It is the same hierarchy as GaussianConditional, since HybridGaussianConditional inherits from HybridGaussianFactor; however, HybridConditional complicates things (or at least makes it messier).

Member: What would break if you only had errorTree in HybridGaussianFactor? It would be available in HGC. All done?

Member: Why is a check needed? We would never call errorTree on HybridConditional, as it only makes sense for Gaussian variants, right?

Collaborator Author: A HybridConditional can have a GaussianConditional or a HybridGaussianConditional as its underlying conditional.

Remember that HybridConditional is a container based on type erasure and not a base class for HybridGaussianConditional.

Member: Right. But then why would we need an errorTree function for HybridConditional? If it can be either, then a method taking VectorValues does not make sense to me. Where would that be used? Is it used?

Collaborator Author: It is used in an incremental sense: when we add the conditionals from a HybridBayesNet to the HybridGaussianFactorGraph, the factor graph then contains HybridGaussianFactors (and potentially HybridGaussianConditionals) as well as HybridConditionals.

The alternative could be to unwrap the conditionals from the HybridConditional before adding it to the graph.

I am open to suggestions about this since this was nontrivial for me. I think checking for HybridGaussianFactor and HybridConditional separately makes sense.
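For context, a sketch of that incremental flow, assuming the standard GTSAM elimination API (eliminateSequential and push_back of a Bayes net); the helper name incrementalErrorTree and the variable names are illustrative:

#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
#include <gtsam/inference/Ordering.h>

using namespace gtsam;

AlgebraicDecisionTree<Key> incrementalErrorTree(
    const HybridGaussianFactorGraph& graph, const Ordering& ordering,
    const HybridGaussianFactorGraph& newFactors,
    const VectorValues& continuousValues) {
  // Eliminate: the result is a HybridBayesNet of HybridConditionals.
  HybridBayesNet::shared_ptr bayesNet = graph.eliminateSequential(ordering);

  // Re-insert the conditionals for the next incremental step; they enter the
  // graph as HybridConditionals, i.e. as HybridFactors.
  HybridGaussianFactorGraph incremental;
  incremental.push_back(*bayesNet);
  incremental.push_back(newFactors);

  // errorTree dispatches through the common HybridFactor interface, so the
  // wrapped conditionals and the new factors are all handled by one loop.
  return incremental.errorTree(continuousValues);
}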

Collaborator Author: Oh, and another case: if we go from a HybridBayesNet to a HybridGaussianFactorGraph, then all the factors are added as HybridConditionals, which are represented as HybridFactors. This causes DiscreteFactors to be considered in the errorTree computations as well. :(

AlgebraicDecisionTree<Key> errorTree(
const VectorValues& continuousValues) const override {
throw std::runtime_error(
"HybridNonlinearFactor::error does not take VectorValues.");
}

public:
HybridNonlinearFactor() = default;

4 changes: 4 additions & 0 deletions gtsam/hybrid/HybridNonlinearFactorGraph.h
@@ -86,6 +86,10 @@ class GTSAM_EXPORT HybridNonlinearFactorGraph : public HybridFactorGraph {
*/
std::shared_ptr<HybridGaussianFactorGraph> linearize(
const Values& continuousValues) const;

/// Expose error(const HybridValues&) method.
using Base::error;

/// @}
};

52 changes: 42 additions & 10 deletions gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp
@@ -51,7 +51,7 @@ using symbol_shorthand::X;
* Test that any linearizedFactorGraph gaussian factors are appended to the
* existing gaussian factor graph in the hybrid factor graph.
*/
TEST(HybridFactorGraph, GaussianFactorGraph) {
TEST(HybridNonlinearFactorGraph, GaussianFactorGraph) {
HybridNonlinearFactorGraph fg;

// Add a simple prior factor to the nonlinear factor graph
@@ -181,7 +181,7 @@ TEST(HybridGaussianFactorGraph, HybridNonlinearFactor) {
/*****************************************************************************
* Test push_back on HFG makes the correct distinction.
*/
TEST(HybridFactorGraph, PushBack) {
TEST(HybridNonlinearFactorGraph, PushBack) {
HybridNonlinearFactorGraph fg;

auto nonlinearFactor = std::make_shared<BetweenFactor<double>>();
@@ -240,7 +240,7 @@ TEST(HybridFactorGraph, PushBack) {
/****************************************************************************
* Test construction of switching-like hybrid factor graph.
*/
TEST(HybridFactorGraph, Switching) {
TEST(HybridNonlinearFactorGraph, Switching) {
Switching self(3);

EXPECT_LONGS_EQUAL(7, self.nonlinearFactorGraph.size());
@@ -250,7 +250,7 @@ TEST(HybridFactorGraph, Switching) {
/****************************************************************************
* Test linearization on a switching-like hybrid factor graph.
*/
TEST(HybridFactorGraph, Linearization) {
TEST(HybridNonlinearFactorGraph, Linearization) {
Switching self(3);

// Linearize here:
@@ -263,7 +263,7 @@ TEST(HybridFactorGraph, Linearization) {
/****************************************************************************
* Test elimination tree construction
*/
TEST(HybridFactorGraph, EliminationTree) {
TEST(HybridNonlinearFactorGraph, EliminationTree) {
Switching self(3);

// Create ordering.
@@ -372,7 +372,7 @@ TEST(HybridGaussianElimination, EliminateHybrid_2_Variable) {
/****************************************************************************
* Test partial elimination
*/
TEST(HybridFactorGraph, Partial_Elimination) {
TEST(HybridNonlinearFactorGraph, Partial_Elimination) {
Switching self(3);

auto linearizedFactorGraph = self.linearizedFactorGraph;
@@ -401,7 +401,39 @@ TEST(HybridFactorGraph, Partial_Elimination) {
EXPECT(remainingFactorGraph->at(2)->keys() == KeyVector({M(0), M(1)}));
}

TEST(HybridFactorGraph, PrintErrors) {
/* ****************************************************************************/
TEST(HybridNonlinearFactorGraph, Error) {
Switching self(3);
HybridNonlinearFactorGraph fg = self.nonlinearFactorGraph;

{
HybridValues values(VectorValues(), DiscreteValues{{M(0), 0}, {M(1), 0}},
self.linearizationPoint);
// regression
EXPECT_DOUBLES_EQUAL(152.791759469, fg.error(values), 1e-9);
}
{
HybridValues values(VectorValues(), DiscreteValues{{M(0), 0}, {M(1), 1}},
self.linearizationPoint);
// regression
EXPECT_DOUBLES_EQUAL(151.598612289, fg.error(values), 1e-9);
}
{
HybridValues values(VectorValues(), DiscreteValues{{M(0), 1}, {M(1), 0}},
self.linearizationPoint);
// regression
EXPECT_DOUBLES_EQUAL(151.703972804, fg.error(values), 1e-9);
}
{
HybridValues values(VectorValues(), DiscreteValues{{M(0), 1}, {M(1), 1}},
self.linearizationPoint);
// regression
EXPECT_DOUBLES_EQUAL(151.609437912, fg.error(values), 1e-9);
}
}

/* ****************************************************************************/
TEST(HybridNonlinearFactorGraph, PrintErrors) {
Switching self(3);

// Get nonlinear factor graph and add linear factors to be holistic
Expand All @@ -424,7 +456,7 @@ TEST(HybridFactorGraph, PrintErrors) {
/****************************************************************************
* Test full elimination
*/
TEST(HybridFactorGraph, Full_Elimination) {
TEST(HybridNonlinearFactorGraph, Full_Elimination) {
Switching self(3);

auto linearizedFactorGraph = self.linearizedFactorGraph;
@@ -492,7 +524,7 @@ TEST(HybridFactorGraph, Full_Elimination) {
/****************************************************************************
* Test printing
*/
TEST(HybridFactorGraph, Printing) {
TEST(HybridNonlinearFactorGraph, Printing) {
Switching self(3);

auto linearizedFactorGraph = self.linearizedFactorGraph;
Expand Down Expand Up @@ -784,7 +816,7 @@ conditional 2: Hybrid P( x2 | m0 m1)
* The issue arises if we eliminate a landmark variable first since it is not
* connected to a HybridFactor.
*/
TEST(HybridFactorGraph, DefaultDecisionTree) {
TEST(HybridNonlinearFactorGraph, DefaultDecisionTree) {
HybridNonlinearFactorGraph fg;

// Add a prior on pose x0 at the origin.