Skip to content

Commit df0ff8a

Browse files
committed
Merge branch 'develop' into hybrid-cleanup
2 parents 6c9f7ec + 2c140df commit df0ff8a

25 files changed

+190
-166
lines changed

gtsam/hybrid/HybridFactor.cpp

+12-14
Original file line numberDiff line numberDiff line change
@@ -50,22 +50,20 @@ DiscreteKeys CollectDiscreteKeys(const DiscreteKeys &key1,
5050

5151
/* ************************************************************************ */
5252
HybridFactor::HybridFactor(const KeyVector &keys)
53-
: Base(keys),
54-
category_(HybridCategory::Continuous),
55-
continuousKeys_(keys) {}
53+
: Base(keys), category_(Category::Continuous), continuousKeys_(keys) {}
5654

5755
/* ************************************************************************ */
58-
HybridCategory GetCategory(const KeyVector &continuousKeys,
59-
const DiscreteKeys &discreteKeys) {
56+
HybridFactor::Category GetCategory(const KeyVector &continuousKeys,
57+
const DiscreteKeys &discreteKeys) {
6058
if ((continuousKeys.size() == 0) && (discreteKeys.size() != 0)) {
61-
return HybridCategory::Discrete;
59+
return HybridFactor::Category::Discrete;
6260
} else if ((continuousKeys.size() != 0) && (discreteKeys.size() == 0)) {
63-
return HybridCategory::Continuous;
61+
return HybridFactor::Category::Continuous;
6462
} else if ((continuousKeys.size() != 0) && (discreteKeys.size() != 0)) {
65-
return HybridCategory::Hybrid;
63+
return HybridFactor::Category::Hybrid;
6664
} else {
6765
// Case where we have no keys. Should never happen.
68-
return HybridCategory::None;
66+
return HybridFactor::Category::None;
6967
}
7068
}
7169

@@ -80,7 +78,7 @@ HybridFactor::HybridFactor(const KeyVector &continuousKeys,
8078
/* ************************************************************************ */
8179
HybridFactor::HybridFactor(const DiscreteKeys &discreteKeys)
8280
: Base(CollectKeys({}, discreteKeys)),
83-
category_(HybridCategory::Discrete),
81+
category_(Category::Discrete),
8482
discreteKeys_(discreteKeys),
8583
continuousKeys_({}) {}
8684

@@ -97,16 +95,16 @@ void HybridFactor::print(const std::string &s,
9795
const KeyFormatter &formatter) const {
9896
std::cout << (s.empty() ? "" : s + "\n");
9997
switch (category_) {
100-
case HybridCategory::Continuous:
98+
case Category::Continuous:
10199
std::cout << "Continuous ";
102100
break;
103-
case HybridCategory::Discrete:
101+
case Category::Discrete:
104102
std::cout << "Discrete ";
105103
break;
106-
case HybridCategory::Hybrid:
104+
case Category::Hybrid:
107105
std::cout << "Hybrid ";
108106
break;
109-
case HybridCategory::None:
107+
case Category::None:
110108
std::cout << "None ";
111109
break;
112110
}

gtsam/hybrid/HybridFactor.h

+8-7
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,6 @@ KeyVector CollectKeys(const KeyVector &keys1, const KeyVector &keys2);
4141
DiscreteKeys CollectDiscreteKeys(const DiscreteKeys &key1,
4242
const DiscreteKeys &key2);
4343

44-
/// Enum to help with categorizing hybrid factors.
45-
enum class HybridCategory { None, Discrete, Continuous, Hybrid };
46-
4744
/**
4845
* Base class for *truly* hybrid probabilistic factors
4946
*
@@ -55,9 +52,13 @@ enum class HybridCategory { None, Discrete, Continuous, Hybrid };
5552
* @ingroup hybrid
5653
*/
5754
class GTSAM_EXPORT HybridFactor : public Factor {
55+
public:
56+
/// Enum to help with categorizing hybrid factors.
57+
enum class Category { None, Discrete, Continuous, Hybrid };
58+
5859
private:
5960
/// Record what category of HybridFactor this is.
60-
HybridCategory category_ = HybridCategory::None;
61+
Category category_ = Category::None;
6162

6263
protected:
6364
// Set of DiscreteKeys for this factor.
@@ -118,13 +119,13 @@ class GTSAM_EXPORT HybridFactor : public Factor {
118119
/// @{
119120

120121
/// True if this is a factor of discrete variables only.
121-
bool isDiscrete() const { return category_ == HybridCategory::Discrete; }
122+
bool isDiscrete() const { return category_ == Category::Discrete; }
122123

123124
/// True if this is a factor of continuous variables only.
124-
bool isContinuous() const { return category_ == HybridCategory::Continuous; }
125+
bool isContinuous() const { return category_ == Category::Continuous; }
125126

126127
/// True is this is a Discrete-Continuous factor.
127-
bool isHybrid() const { return category_ == HybridCategory::Hybrid; }
128+
bool isHybrid() const { return category_ == Category::Hybrid; }
128129

129130
/// Return the number of continuous variables in this factor.
130131
size_t nrContinuous() const { return continuousKeys_.size(); }

gtsam/hybrid/HybridGaussianConditional.cpp

+3-12
Original file line numberDiff line numberDiff line change
@@ -55,23 +55,14 @@ HybridGaussianConditional::conditionals() const {
5555
return conditionals_;
5656
}
5757

58-
/* *******************************************************************************/
59-
HybridGaussianConditional::HybridGaussianConditional(
60-
KeyVector &&continuousFrontals, KeyVector &&continuousParents,
61-
DiscreteKeys &&discreteParents,
62-
std::vector<GaussianConditional::shared_ptr> &&conditionals)
63-
: HybridGaussianConditional(continuousFrontals, continuousParents,
64-
discreteParents,
65-
Conditionals(discreteParents, conditionals)) {}
66-
6758
/* *******************************************************************************/
6859
HybridGaussianConditional::HybridGaussianConditional(
6960
const KeyVector &continuousFrontals, const KeyVector &continuousParents,
70-
const DiscreteKeys &discreteParents,
61+
const DiscreteKey &discreteParent,
7162
const std::vector<GaussianConditional::shared_ptr> &conditionals)
7263
: HybridGaussianConditional(continuousFrontals, continuousParents,
73-
discreteParents,
74-
Conditionals(discreteParents, conditionals)) {}
64+
DiscreteKeys{discreteParent},
65+
Conditionals({discreteParent}, conditionals)) {}
7566

7667
/* *******************************************************************************/
7768
// TODO(dellaert): This is copy/paste: HybridGaussianConditional should be

gtsam/hybrid/HybridGaussianConditional.h

+6-17
Original file line numberDiff line numberDiff line change
@@ -107,29 +107,18 @@ class GTSAM_EXPORT HybridGaussianConditional
107107
const Conditionals &conditionals);
108108

109109
/**
110-
* @brief Make a Gaussian Mixture from a list of Gaussian conditionals
110+
* @brief Make a Gaussian Mixture from a vector of Gaussian conditionals.
111+
* The DecisionTree-based constructor is preferred over this one.
111112
*
112113
* @param continuousFrontals The continuous frontal variables
113114
* @param continuousParents The continuous parent variables
114-
* @param discreteParents Discrete parents variables
115-
* @param conditionals List of conditionals
116-
*/
117-
HybridGaussianConditional(
118-
KeyVector &&continuousFrontals, KeyVector &&continuousParents,
119-
DiscreteKeys &&discreteParents,
120-
std::vector<GaussianConditional::shared_ptr> &&conditionals);
121-
122-
/**
123-
* @brief Make a Gaussian Mixture from a list of Gaussian conditionals
124-
*
125-
* @param continuousFrontals The continuous frontal variables
126-
* @param continuousParents The continuous parent variables
127-
* @param discreteParents Discrete parents variables
128-
* @param conditionals List of conditionals
115+
* @param discreteParent Single discrete parent variable
116+
* @param conditionals Vector of conditionals with the same size as the
117+
* cardinality of the discrete parent.
129118
*/
130119
HybridGaussianConditional(
131120
const KeyVector &continuousFrontals, const KeyVector &continuousParents,
132-
const DiscreteKeys &discreteParents,
121+
const DiscreteKey &discreteParent,
133122
const std::vector<GaussianConditional::shared_ptr> &conditionals);
134123

135124
/// @}

gtsam/hybrid/HybridGaussianFactor.cpp

+5-3
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,11 @@ HybridGaussianFactor::Factors augment(
4242
const HybridGaussianFactor::FactorValuePairs &factors) {
4343
// Find the minimum value so we can "proselytize" to positive values.
4444
// Done because we can't have sqrt of negative numbers.
45-
auto unzipped_pair = unzip(factors);
46-
const HybridGaussianFactor::Factors gaussianFactors = unzipped_pair.first;
47-
const AlgebraicDecisionTree<Key> valueTree = unzipped_pair.second;
45+
HybridGaussianFactor::Factors gaussianFactors;
46+
AlgebraicDecisionTree<Key> valueTree;
47+
std::tie(gaussianFactors, valueTree) = unzip(factors);
48+
49+
// Normalize
4850
double min_value = valueTree.min();
4951
AlgebraicDecisionTree<Key> values =
5052
valueTree.apply([&min_value](double n) { return n - min_value; });

gtsam/hybrid/HybridGaussianFactor.h

+5-5
Original file line numberDiff line numberDiff line change
@@ -96,15 +96,15 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
9696
* GaussianFactor shared pointers.
9797
*
9898
* @param continuousKeys Vector of keys for continuous factors.
99-
* @param discreteKeys Vector of discrete keys.
99+
* @param discreteKey The discrete key to index each component.
100100
* @param factors Vector of gaussian factor shared pointers
101-
* and arbitrary scalars.
101+
* and arbitrary scalars. Same size as the cardinality of discreteKey.
102102
*/
103103
HybridGaussianFactor(const KeyVector &continuousKeys,
104-
const DiscreteKeys &discreteKeys,
104+
const DiscreteKey &discreteKey,
105105
const std::vector<GaussianFactorValuePair> &factors)
106-
: HybridGaussianFactor(continuousKeys, discreteKeys,
107-
FactorValuePairs(discreteKeys, factors)) {}
106+
: HybridGaussianFactor(continuousKeys, {discreteKey},
107+
FactorValuePairs({discreteKey}, factors)) {}
108108

109109
/// @}
110110
/// @name Testable

gtsam/hybrid/HybridNonlinearFactor.h

+7-6
Original file line numberDiff line numberDiff line change
@@ -89,14 +89,15 @@ class HybridNonlinearFactor : public HybridFactor {
8989
* @tparam FACTOR The type of the factor shared pointers being passed in.
9090
* Will be typecast to NonlinearFactor shared pointers.
9191
* @param keys Vector of keys for continuous factors.
92-
* @param discreteKeys Vector of discrete keys.
92+
* @param discreteKey The discrete key indexing each component factor.
9393
* @param factors Vector of nonlinear factor and scalar pairs.
94+
* Same size as the cardinality of discreteKey.
9495
*/
9596
template <typename FACTOR>
9697
HybridNonlinearFactor(
97-
const KeyVector& keys, const DiscreteKeys& discreteKeys,
98+
const KeyVector& keys, const DiscreteKey& discreteKey,
9899
const std::vector<std::pair<std::shared_ptr<FACTOR>, double>>& factors)
99-
: Base(keys, discreteKeys) {
100+
: Base(keys, {discreteKey}) {
100101
std::vector<NonlinearFactorValuePair> nonlinear_factors;
101102
KeySet continuous_keys_set(keys.begin(), keys.end());
102103
KeySet factor_keys_set;
@@ -112,7 +113,7 @@ class HybridNonlinearFactor : public HybridFactor {
112113
"Factors passed into HybridNonlinearFactor need to be nonlinear!");
113114
}
114115
}
115-
factors_ = Factors(discreteKeys, nonlinear_factors);
116+
factors_ = Factors({discreteKey}, nonlinear_factors);
116117

117118
if (continuous_keys_set != factor_keys_set) {
118119
throw std::runtime_error(
@@ -134,7 +135,7 @@ class HybridNonlinearFactor : public HybridFactor {
134135
auto errorFunc =
135136
[continuousValues](const std::pair<sharedFactor, double>& f) {
136137
auto [factor, val] = f;
137-
return factor->error(continuousValues) + (0.5 * val * val);
138+
return factor->error(continuousValues) + (0.5 * val);
138139
};
139140
DecisionTree<Key, double> result(factors_, errorFunc);
140141
return result;
@@ -153,7 +154,7 @@ class HybridNonlinearFactor : public HybridFactor {
153154
auto [factor, val] = factors_(discreteValues);
154155
// Compute the error for the selected factor
155156
const double factorError = factor->error(continuousValues);
156-
return factorError + (0.5 * val * val);
157+
return factorError + (0.5 * val);
157158
}
158159

159160
/**

gtsam/hybrid/hybrid.i

+8-4
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ virtual class HybridConditional {
7676
class HybridGaussianFactor : gtsam::HybridFactor {
7777
HybridGaussianFactor(
7878
const gtsam::KeyVector& continuousKeys,
79-
const gtsam::DiscreteKeys& discreteKeys,
79+
const gtsam::DiscreteKey& discreteKey,
8080
const std::vector<std::pair<gtsam::GaussianFactor::shared_ptr, double>>&
8181
factorsList);
8282

@@ -91,8 +91,12 @@ class HybridGaussianConditional : gtsam::HybridFactor {
9191
const gtsam::KeyVector& continuousFrontals,
9292
const gtsam::KeyVector& continuousParents,
9393
const gtsam::DiscreteKeys& discreteParents,
94-
const std::vector<gtsam::GaussianConditional::shared_ptr>&
95-
conditionalsList);
94+
const gtsam::HybridGaussianConditional::Conditionals& conditionals);
95+
HybridGaussianConditional(
96+
const gtsam::KeyVector& continuousFrontals,
97+
const gtsam::KeyVector& continuousParents,
98+
const gtsam::DiscreteKey& discreteParent,
99+
const std::vector<gtsam::GaussianConditional::shared_ptr>& conditionals);
96100

97101
gtsam::HybridGaussianFactor* likelihood(
98102
const gtsam::VectorValues& frontals) const;
@@ -248,7 +252,7 @@ class HybridNonlinearFactor : gtsam::HybridFactor {
248252
bool normalized = false);
249253

250254
HybridNonlinearFactor(
251-
const gtsam::KeyVector& keys, const gtsam::DiscreteKeys& discreteKeys,
255+
const gtsam::KeyVector& keys, const gtsam::DiscreteKey& discreteKey,
252256
const std::vector<std::pair<gtsam::NonlinearFactor*, double>>& factors,
253257
bool normalized = false);
254258

gtsam/hybrid/tests/Switching.h

+12-11
Original file line numberDiff line numberDiff line change
@@ -57,15 +57,16 @@ inline HybridGaussianFactorGraph::shared_ptr makeSwitchingChain(
5757

5858
// keyFunc(1) to keyFunc(n+1)
5959
for (size_t t = 1; t < n; t++) {
60-
std::vector<GaussianFactorValuePair> components = {
61-
{std::make_shared<JacobianFactor>(keyFunc(t), I_3x3, keyFunc(t + 1),
62-
I_3x3, Z_3x1),
63-
0.0},
64-
{std::make_shared<JacobianFactor>(keyFunc(t), I_3x3, keyFunc(t + 1),
65-
I_3x3, Vector3::Ones()),
66-
0.0}};
67-
hfg.add(HybridGaussianFactor({keyFunc(t), keyFunc(t + 1)},
68-
{{dKeyFunc(t), 2}}, components));
60+
DiscreteKeys dKeys{{dKeyFunc(t), 2}};
61+
HybridGaussianFactor::FactorValuePairs components(
62+
dKeys, {{std::make_shared<JacobianFactor>(keyFunc(t), I_3x3,
63+
keyFunc(t + 1), I_3x3, Z_3x1),
64+
0.0},
65+
{std::make_shared<JacobianFactor>(
66+
keyFunc(t), I_3x3, keyFunc(t + 1), I_3x3, Vector3::Ones()),
67+
0.0}});
68+
hfg.add(
69+
HybridGaussianFactor({keyFunc(t), keyFunc(t + 1)}, dKeys, components));
6970

7071
if (t > 1) {
7172
hfg.add(DecisionTreeFactor({{dKeyFunc(t - 1), 2}, {dKeyFunc(t), 2}},
@@ -167,8 +168,8 @@ struct Switching {
167168
components.push_back(
168169
{std::dynamic_pointer_cast<NonlinearFactor>(f), 0.0});
169170
}
170-
nonlinearFactorGraph.emplace_shared<HybridNonlinearFactor>(
171-
keys, DiscreteKeys{modes[k]}, components);
171+
nonlinearFactorGraph.emplace_shared<HybridNonlinearFactor>(keys, modes[k],
172+
components);
172173
}
173174

174175
// Add measurement factors

gtsam/hybrid/tests/TinyHybridExample.h

+4-5
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,11 @@ inline HybridBayesNet createHybridBayesNet(size_t num_measurements = 1,
4343
// Create Gaussian mixture z_i = x0 + noise for each measurement.
4444
for (size_t i = 0; i < num_measurements; i++) {
4545
const auto mode_i = manyModes ? DiscreteKey{M(i), 2} : mode;
46+
std::vector<GaussianConditional::shared_ptr> conditionals{
47+
GaussianConditional::sharedMeanAndStddev(Z(i), I_1x1, X(0), Z_1x1, 0.5),
48+
GaussianConditional::sharedMeanAndStddev(Z(i), I_1x1, X(0), Z_1x1, 3)};
4649
bayesNet.emplace_shared<HybridGaussianConditional>(
47-
KeyVector{Z(i)}, KeyVector{X(0)}, DiscreteKeys{mode_i},
48-
std::vector{GaussianConditional::sharedMeanAndStddev(Z(i), I_1x1, X(0),
49-
Z_1x1, 0.5),
50-
GaussianConditional::sharedMeanAndStddev(Z(i), I_1x1, X(0),
51-
Z_1x1, 3)});
50+
KeyVector{Z(i)}, KeyVector{X(0)}, mode_i, conditionals);
5251
}
5352

5453
// Create prior on X(0).

gtsam/hybrid/tests/testHybridBayesNet.cpp

+8-9
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ TEST(HybridBayesNet, evaluateHybrid) {
108108
HybridBayesNet bayesNet;
109109
bayesNet.push_back(continuousConditional);
110110
bayesNet.emplace_shared<HybridGaussianConditional>(
111-
KeyVector{X(1)}, KeyVector{}, DiscreteKeys{Asia},
111+
KeyVector{X(1)}, KeyVector{}, Asia,
112112
std::vector{conditional0, conditional1});
113113
bayesNet.emplace_shared<DiscreteConditional>(Asia, "99/1");
114114

@@ -169,7 +169,7 @@ TEST(HybridBayesNet, Error) {
169169
X(1), Vector1::Constant(2), I_1x1, model1);
170170

171171
auto gm = std::make_shared<HybridGaussianConditional>(
172-
KeyVector{X(1)}, KeyVector{}, DiscreteKeys{Asia},
172+
KeyVector{X(1)}, KeyVector{}, Asia,
173173
std::vector{conditional0, conditional1});
174174
// Create hybrid Bayes net.
175175
HybridBayesNet bayesNet;
@@ -383,17 +383,16 @@ TEST(HybridBayesNet, Sampling) {
383383
HybridNonlinearFactorGraph nfg;
384384

385385
auto noise_model = noiseModel::Diagonal::Sigmas(Vector1(1.0));
386+
nfg.emplace_shared<PriorFactor<double>>(X(0), 0.0, noise_model);
387+
386388
auto zero_motion =
387389
std::make_shared<BetweenFactor<double>>(X(0), X(1), 0, noise_model);
388390
auto one_motion =
389391
std::make_shared<BetweenFactor<double>>(X(0), X(1), 1, noise_model);
390-
391-
DiscreteKeys discreteKeys{DiscreteKey(M(0), 2)};
392-
HybridNonlinearFactor::Factors factors(
393-
discreteKeys, {{zero_motion, 0.0}, {one_motion, 0.0}});
394-
nfg.emplace_shared<PriorFactor<double>>(X(0), 0.0, noise_model);
395-
nfg.emplace_shared<HybridNonlinearFactor>(KeyVector{X(0), X(1)}, discreteKeys,
396-
factors);
392+
nfg.emplace_shared<HybridNonlinearFactor>(
393+
KeyVector{X(0), X(1)}, DiscreteKey(M(0), 2),
394+
std::vector<NonlinearFactorValuePair>{{zero_motion, 0.0},
395+
{one_motion, 0.0}});
397396

398397
DiscreteKey mode(M(0), 2);
399398
nfg.emplace_shared<DiscreteDistribution>(mode, "1/1");

0 commit comments

Comments
 (0)