-
Notifications
You must be signed in to change notification settings - Fork 805
/
Copy pathHybridGaussianFactorGraph.h
219 lines (186 loc) · 7.01 KB
/
HybridGaussianFactorGraph.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file HybridGaussianFactorGraph.h
* @brief Linearized Hybrid factor graph that uses type erasure
* @author Fan Jiang, Varun Agrawal, Frank Dellaert
* @date Mar 11, 2022
*/
#pragma once
#include <gtsam/hybrid/GaussianMixtureFactor.h>
#include <gtsam/hybrid/HybridFactor.h>
#include <gtsam/hybrid/HybridFactorGraph.h>
#include <gtsam/inference/EliminateableFactorGraph.h>
#include <gtsam/inference/FactorGraph.h>
#include <gtsam/inference/Ordering.h>
#include <gtsam/linear/GaussianFactor.h>
#include <gtsam/linear/VectorValues.h>
#include <functional>
#include <optional>
namespace gtsam {
// Forward declarations.
// These names are only declared here (not defined), so this header can refer
// to them by reference, by pointer/shared_ptr, or as type aliases without
// including their full definitions.
class HybridGaussianFactorGraph;
class HybridConditional;
class HybridBayesNet;
class HybridEliminationTree;
class HybridBayesTree;
class HybridJunctionTree;
class DecisionTreeFactor;
class TableFactor;
class JacobianFactor;
class HybridValues;
/**
* @brief Main elimination function for HybridGaussianFactorGraph.
*
* This is the function that
* EliminationTraits<HybridGaussianFactorGraph>::DefaultEliminate forwards to.
* Only the declaration lives in this header.
*
* @param factors The factor graph to eliminate.
* @param keys The elimination ordering.
* @return A pair: first, the HybridConditional on the ordering keys; second,
* the remaining factor (a single Factor, held by shared_ptr).
* @ingroup hybrid
*/
GTSAM_EXPORT
std::pair<std::shared_ptr<HybridConditional>, std::shared_ptr<Factor>>
EliminateHybrid(const HybridGaussianFactorGraph& factors, const Ordering& keys);
/**
* @brief Return a Colamd constrained ordering where the discrete keys are
* eliminated after the continuous keys.
*
* This is the function that
* EliminationTraits<HybridGaussianFactorGraph>::DefaultOrderingFunc forwards
* to. Only the declaration lives in this header.
*
* NOTE(review): the result is returned by value as `const Ordering`; a const
* by-value return cannot be moved from by callers. Changing it would require
* updating the definition as well.
*
* @param graph The hybrid Gaussian factor graph to compute the ordering for.
* @return const Ordering The elimination ordering, returned by value.
*/
GTSAM_EXPORT const Ordering
HybridOrdering(const HybridGaussianFactorGraph& graph);
/* ************************************************************************* */
template <>
struct EliminationTraits<HybridGaussianFactorGraph> {
  /// Type of factors in factor graph
  using FactorType = Factor;

  /// Type of the factor graph (e.g. HybridGaussianFactorGraph)
  using FactorGraphType = HybridGaussianFactorGraph;

  /// Type of conditionals from elimination
  using ConditionalType = HybridConditional;

  /// Type of Bayes net from sequential elimination
  using BayesNetType = HybridBayesNet;

  /// Type of elimination tree
  using EliminationTreeType = HybridEliminationTree;

  /// Type of Bayes tree
  using BayesTreeType = HybridBayesTree;

  /// Type of Junction tree
  using JunctionTreeType = HybridJunctionTree;

  /**
   * The default dense elimination function: forwards to EliminateHybrid.
   *
   * @param factors The factor graph to eliminate.
   * @param keys The elimination ordering.
   * @return The (conditional, remaining factor) pair from EliminateHybrid.
   */
  static std::pair<std::shared_ptr<ConditionalType>,
                   std::shared_ptr<FactorType>>
  DefaultEliminate(const FactorGraphType& factors, const Ordering& keys) {
    return EliminateHybrid(factors, keys);
  }

  /**
   * The default ordering generation function: forwards to HybridOrdering.
   * The optional VariableIndex argument is accepted but not used.
   *
   * @param graph The factor graph to compute an ordering for.
   * @return The ordering produced by HybridOrdering(graph).
   */
  static Ordering DefaultOrderingFunc(
      const FactorGraphType& graph,
      std::optional<std::reference_wrapper<const VariableIndex>>) {
    return HybridOrdering(graph);
  }
};
/**
* Hybrid Gaussian Factor Graph
* -----------------------
* This is the linearized version of a hybrid factor graph.
*
* Inherits from both HybridFactorGraph (aliased below as `Base`) and
* EliminateableFactorGraph<HybridGaussianFactorGraph> (aliased below as
* `BaseEliminateable`). The associated elimination types are declared in the
* EliminationTraits<HybridGaussianFactorGraph> specialization.
*
* @ingroup hybrid
*/
class GTSAM_EXPORT HybridGaussianFactorGraph
: public HybridFactorGraph,
public EliminateableFactorGraph<HybridGaussianFactorGraph> {
protected:
/// Check if FACTOR type is derived from GaussianFactor.
/// SFINAE helper: names `void` when FACTOR is GaussianFactor or derives from
/// it; otherwise the alias is ill-formed (substitution failure).
template <typename FACTOR>
using IsGaussian = typename std::enable_if<
std::is_base_of<GaussianFactor, FACTOR>::value>::type;
public:
using Base = HybridFactorGraph; ///< the HybridFactorGraph base class
using This = HybridGaussianFactorGraph; ///< this class
/// The EliminateableFactorGraph base class, used for elimination.
using BaseEliminateable = EliminateableFactorGraph<This>;
using shared_ptr = std::shared_ptr<This>; ///< shared_ptr to This
using Values = gtsam::Values; ///< backwards compatibility
/// Alias of KeyVector.
/// NOTE(review): presumably a sequence of keys rather than a key-to-value
/// map -- confirm against the KeyVector definition.
using Indices = KeyVector;
/// @name Constructors
/// @{
/// @brief Default constructor.
HybridGaussianFactorGraph() = default;
/**
* Construct from container of factors (shared_ptr or plain objects).
* Explicit, so a container never converts to a graph implicitly; the
* container is forwarded unchanged to the Base constructor.
*/
template <class CONTAINER>
explicit HybridGaussianFactorGraph(const CONTAINER& factors)
: Base(factors) {}
/**
* Implicit copy/downcast constructor to override explicit template container
* constructor. In BayesTree this is used for:
* `cachedSeparatorMarginal_.reset(*separatorMarginal)`
*
* Intentionally not `explicit`; the graph is forwarded to the Base
* constructor.
* */
template <class DERIVEDFACTOR>
HybridGaussianFactorGraph(const FactorGraph<DERIVEDFACTOR>& graph)
: Base(graph) {}
/// @}
/// @name Testable
/// @{
// TODO(dellaert): customize print and equals.
// void print(
// const std::string& s = "HybridGaussianFactorGraph",
// const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override;
/**
* @brief Print errors for the factors in this graph at the given values.
* Only the declaration is in this header.
*
* @param values Hybrid values at which the errors are evaluated.
* @param str String passed along for printing (defaults to
* "HybridGaussianFactorGraph: "); presumably a prefix -- confirm against the
* implementation.
* @param keyFormatter Formatter used to print keys.
* @param printCondition Predicate receiving the factor pointer, its whitened
* error and its index. The default always returns true. Presumably a factor
* is printed only when the predicate returns true -- confirm against the
* implementation.
*/
void printErrors(
const HybridValues& values,
const std::string& str = "HybridGaussianFactorGraph: ",
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const std::function<bool(const Factor* /*factor*/,
double /*whitenedError*/, size_t /*index*/)>&
printCondition =
[](const Factor*, double, size_t) { return true; }) const;
// bool equals(const This& fg, double tol = 1e-9) const override;
/// @}
/// @name Standard Interface
/// @{
/// Expose error(const HybridValues&) method.
/// The using-declaration keeps Base::error visible in this class scope.
using Base::error;
/**
* @brief Compute error for each discrete assignment,
* and return as a tree.
*
* Error \f$ e = \Vert x - \mu \Vert_{\Sigma} \f$.
*
* @param continuousValues Continuous values at which to compute the error.
* @return AlgebraicDecisionTree<Key> Tree holding one error value per
* discrete assignment.
*/
AlgebraicDecisionTree<Key> errorTree(
const VectorValues& continuousValues) const;
/**
* @brief Compute unnormalized probability \f$ P(X | M, Z) \f$
* for each discrete assignment, and return as a tree.
*
* @param continuousValues Continuous values at which to compute the
* probability.
* @return AlgebraicDecisionTree<Key> Tree holding one unnormalized
* probability per discrete assignment.
*/
AlgebraicDecisionTree<Key> probPrime(
const VectorValues& continuousValues) const;
/**
* @brief Compute the unnormalized posterior probability for a continuous
* vector values given a specific assignment.
*
* @param values Hybrid values at which to evaluate; presumably these carry
* both the continuous values and the discrete assignment -- confirm against
* HybridValues.
* @return double The unnormalized posterior probability.
*/
double probPrime(const HybridValues& values) const;
/**
* @brief Create a decision tree of factor graphs out of this hybrid factor
* graph.
*
* For example, if there are two mixture factors, one with a discrete key A
* and one with a discrete key B, then the decision tree will have two levels,
* one for A and one for B. The leaves of the tree will be the Gaussian
* factors that have only continuous keys.
*
* @return GaussianFactorGraphTree The assembled decision tree.
*/
GaussianFactorGraphTree assembleGraphTree() const;
/// @}
/**
* @brief Get the GaussianFactorGraph at a given discrete assignment.
*
* @param assignment The discrete assignment selecting the graph.
* @return GaussianFactorGraph The Gaussian factor graph for that assignment,
* returned by value.
*/
GaussianFactorGraph operator()(const DiscreteValues& assignment) const;
};
} // namespace gtsam