-
Notifications
You must be signed in to change notification settings - Fork 805
/
Copy pathGaussianMixtureFactor.h
175 lines (145 loc) · 5.28 KB
/
GaussianMixtureFactor.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file GaussianMixtureFactor.h
* @brief A set of GaussianFactors, indexed by a set of discrete keys.
* @author Fan Jiang
* @author Varun Agrawal
* @author Frank Dellaert
* @date Mar 12, 2022
*/
#pragma once
#include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/discrete/DecisionTree.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/hybrid/HybridFactor.h>
#include <gtsam/linear/GaussianFactor.h>
#include <gtsam/linear/GaussianFactorGraph.h>
namespace gtsam {
// Forward declarations: these value containers appear only as const
// references in the declarations below, so full definitions are not needed.
class HybridValues;
class DiscreteValues;
class VectorValues;

/**
 * @brief Implementation of a discrete conditional mixture factor.
 * Implements a joint discrete-continuous factor where the discrete variable
 * serves to "select" a mixture component corresponding to a GaussianFactor type
 * of measurement.
 *
 * Represents the underlying Gaussian mixture as a Decision Tree, where the set
 * of discrete variables indexes to the continuous Gaussian distribution.
 *
 * @ingroup hybrid
 */
class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
 public:
  using Base = HybridFactor;
  using This = GaussianMixtureFactor;
  using shared_ptr = std::shared_ptr<This>;

  /// Shared pointer to a single Gaussian mixture component.
  using sharedFactor = std::shared_ptr<GaussianFactor>;

  /// Decision tree whose leaves are Gaussian factors, indexed by discrete keys.
  /// (Only factors are stored here; there is no separate log-constant leaf.)
  using Factors = DecisionTree<Key, sharedFactor>;

 private:
  /// Decision tree of Gaussian factors indexed by discrete keys.
  Factors factors_;

  /**
   * @brief Helper that converts the stored factor tree into a decision tree
   * of Gaussian factor graphs.
   *
   * @return GaussianFactorGraphTree built from `factors_`.
   */
  GaussianFactorGraphTree asGaussianFactorGraphTree() const;

 public:
  /// @name Constructors
  /// @{

  /// Default constructor, mainly for serialization.
  GaussianMixtureFactor() = default;

  /**
   * @brief Construct a new Gaussian mixture factor.
   *
   * @param continuousKeys A vector of keys representing continuous variables.
   * @param discreteKeys A vector of keys representing discrete variables and
   * their cardinalities.
   * @param factors The decision tree of Gaussian factors stored as the mixture
   * density.
   */
  GaussianMixtureFactor(const KeyVector &continuousKeys,
                        const DiscreteKeys &discreteKeys,
                        const Factors &factors);

  /**
   * @brief Construct a new GaussianMixtureFactor object using a vector of
   * GaussianFactor shared pointers.
   *
   * The vector is turned into a decision tree over `discreteKeys` and then
   * forwarded to the tree-based constructor.
   *
   * @param continuousKeys Vector of keys for continuous factors.
   * @param discreteKeys Vector of discrete keys.
   * @param factors Vector of Gaussian factor shared pointers, one per discrete
   * assignment (NOTE(review): ordering follows DecisionTree's vector
   * constructor — confirm against DecisionTree.h).
   */
  GaussianMixtureFactor(const KeyVector &continuousKeys,
                        const DiscreteKeys &discreteKeys,
                        const std::vector<sharedFactor> &factors)
      : GaussianMixtureFactor(continuousKeys, discreteKeys,
                              Factors(discreteKeys, factors)) {}

  /// @}
  /// @name Testable
  /// @{

  /**
   * @brief Check equality with another hybrid factor up to a tolerance.
   * @param lf The factor to compare against.
   * @param tol Numerical tolerance for the comparison.
   */
  bool equals(const HybridFactor &lf, double tol = 1e-9) const override;

  /**
   * @brief Print this factor.
   * @param s Prefix string printed before the factor.
   * @param formatter Formatter used to print keys.
   */
  void print(
      const std::string &s = "GaussianMixtureFactor\n",
      const KeyFormatter &formatter = DefaultKeyFormatter) const override;

  /// @}
  /// @name Standard API
  /// @{

  /**
   * @brief Get the Gaussian factor selected by a discrete assignment.
   * @param assignment Values for the discrete keys of this factor.
   * @return The mixture component for that assignment.
   */
  sharedFactor operator()(const DiscreteValues &assignment) const;

  /**
   * @brief Combine the Gaussian Factor Graphs in `sum` and `this` while
   * maintaining the original tree structure.
   *
   * @param sum Decision Tree of Gaussian Factor Graphs indexed by the
   * variables.
   * @return Sum
   */
  GaussianFactorGraphTree add(const GaussianFactorGraphTree &sum) const;

  /**
   * @brief Compute error of the GaussianMixtureFactor as a tree.
   *
   * @param continuousValues The continuous VectorValues.
   * @return AlgebraicDecisionTree<Key> A decision tree with the same keys
   * as the factors involved, and leaf values as the error.
   */
  AlgebraicDecisionTree<Key> errorTree(
      const VectorValues &continuousValues) const;

  /**
   * @brief Compute the error of this factor at the given hybrid values.
   *
   * NOTE(review): an earlier comment described this as "the log-likelihood,
   * including the log-normalizing constant". GTSAM `error` values are
   * conventionally negative log-likelihoods. Confirm in the .cpp whether a
   * normalizing constant is actually included.
   *
   * @param values The hybrid (discrete + continuous) values to evaluate at.
   * @return double The error value.
   */
  double error(const HybridValues &values) const override;

  /// Getter for GaussianFactor decision tree
  const Factors &factors() const { return factors_; }

  /**
   * @brief Add a mixture factor to a sum of factor graph trees.
   * Syntactic sugar for `sum = factor.add(sum)`.
   * @param sum Tree of Gaussian factor graphs, updated in place.
   * @param factor The mixture factor to add.
   * @return Reference to the updated `sum`.
   */
  friend GaussianFactorGraphTree &operator+=(
      GaussianFactorGraphTree &sum, const GaussianMixtureFactor &factor) {
    sum = factor.add(sum);
    return sum;
  }

  /// @}

 private:
#ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
  /** Serialization function: base-class state followed by the factor tree. */
  friend class boost::serialization::access;
  template <class ARCHIVE>
  void serialize(ARCHIVE &ar, const unsigned int /*version*/) {
    ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
    ar &BOOST_SERIALIZATION_NVP(factors_);
  }
#endif
};
/// traits specialization: expose GaussianMixtureFactor through the Testable
/// adaptor. A struct inherits publicly by default, so no access specifier is
/// spelled out.
template <>
struct traits<GaussianMixtureFactor> : Testable<GaussianMixtureFactor> {};
} // namespace gtsam