Skip to content

Commit b895e64

Browse files
committed
Merge branch 'develop' into direct-hybrid-fg
2 parents de68aec + 2c140df commit b895e64

27 files changed

+225
-179
lines changed

gtsam/hybrid/HybridConditional.cpp

+4-11
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,9 @@ HybridConditional::HybridConditional(const KeyVector &continuousFrontals,
2828
const DiscreteKeys &discreteFrontals,
2929
const KeyVector &continuousParents,
3030
const DiscreteKeys &discreteParents)
31-
: HybridConditional(
32-
CollectKeys(
33-
{continuousFrontals.begin(), continuousFrontals.end()},
34-
KeyVector{continuousParents.begin(), continuousParents.end()}),
35-
CollectDiscreteKeys(
36-
{discreteFrontals.begin(), discreteFrontals.end()},
37-
{discreteParents.begin(), discreteParents.end()}),
38-
continuousFrontals.size() + discreteFrontals.size()) {}
31+
: HybridConditional(CollectKeys(continuousFrontals, continuousParents),
32+
CollectDiscreteKeys(discreteFrontals, discreteParents),
33+
continuousFrontals.size() + discreteFrontals.size()) {}
3934

4035
/* ************************************************************************ */
4136
HybridConditional::HybridConditional(
@@ -56,9 +51,7 @@ HybridConditional::HybridConditional(
5651
/* ************************************************************************ */
5752
HybridConditional::HybridConditional(
5853
const std::shared_ptr<HybridGaussianConditional> &gaussianMixture)
59-
: BaseFactor(KeyVector(gaussianMixture->keys().begin(),
60-
gaussianMixture->keys().begin() +
61-
gaussianMixture->nrContinuous()),
54+
: BaseFactor(gaussianMixture->continuousKeys(),
6255
gaussianMixture->discreteKeys()),
6356
BaseConditional(gaussianMixture->nrFrontals()) {
6457
inner_ = gaussianMixture;

gtsam/hybrid/HybridFactor.cpp

+35-11
Original file line numberDiff line numberDiff line change
@@ -50,41 +50,65 @@ DiscreteKeys CollectDiscreteKeys(const DiscreteKeys &key1,
5050

5151
/* ************************************************************************ */
5252
HybridFactor::HybridFactor(const KeyVector &keys)
53-
: Base(keys), isContinuous_(true), continuousKeys_(keys) {}
53+
: Base(keys), category_(Category::Continuous), continuousKeys_(keys) {}
54+
55+
/* ************************************************************************ */
56+
HybridFactor::Category GetCategory(const KeyVector &continuousKeys,
57+
const DiscreteKeys &discreteKeys) {
58+
if ((continuousKeys.size() == 0) && (discreteKeys.size() != 0)) {
59+
return HybridFactor::Category::Discrete;
60+
} else if ((continuousKeys.size() != 0) && (discreteKeys.size() == 0)) {
61+
return HybridFactor::Category::Continuous;
62+
} else if ((continuousKeys.size() != 0) && (discreteKeys.size() != 0)) {
63+
return HybridFactor::Category::Hybrid;
64+
} else {
65+
// Case where we have no keys. Should never happen.
66+
return HybridFactor::Category::None;
67+
}
68+
}
5469

5570
/* ************************************************************************ */
5671
HybridFactor::HybridFactor(const KeyVector &continuousKeys,
5772
const DiscreteKeys &discreteKeys)
5873
: Base(CollectKeys(continuousKeys, discreteKeys)),
59-
isDiscrete_((continuousKeys.size() == 0) && (discreteKeys.size() != 0)),
60-
isContinuous_((continuousKeys.size() != 0) && (discreteKeys.size() == 0)),
61-
isHybrid_((continuousKeys.size() != 0) && (discreteKeys.size() != 0)),
74+
category_(GetCategory(continuousKeys, discreteKeys)),
6275
discreteKeys_(discreteKeys),
6376
continuousKeys_(continuousKeys) {}
6477

6578
/* ************************************************************************ */
6679
HybridFactor::HybridFactor(const DiscreteKeys &discreteKeys)
6780
: Base(CollectKeys({}, discreteKeys)),
68-
isDiscrete_(true),
81+
category_(Category::Discrete),
6982
discreteKeys_(discreteKeys),
7083
continuousKeys_({}) {}
7184

7285
/* ************************************************************************ */
7386
bool HybridFactor::equals(const HybridFactor &lf, double tol) const {
7487
const This *e = dynamic_cast<const This *>(&lf);
75-
return e != nullptr && Base::equals(*e, tol) &&
76-
isDiscrete_ == e->isDiscrete_ && isContinuous_ == e->isContinuous_ &&
77-
isHybrid_ == e->isHybrid_ && continuousKeys_ == e->continuousKeys_ &&
88+
return e != nullptr && Base::equals(*e, tol) && category_ == e->category_ &&
89+
continuousKeys_ == e->continuousKeys_ &&
7890
discreteKeys_ == e->discreteKeys_;
7991
}
8092

8193
/* ************************************************************************ */
8294
void HybridFactor::print(const std::string &s,
8395
const KeyFormatter &formatter) const {
8496
std::cout << (s.empty() ? "" : s + "\n");
85-
if (isContinuous_) std::cout << "Continuous ";
86-
if (isDiscrete_) std::cout << "Discrete ";
87-
if (isHybrid_) std::cout << "Hybrid ";
97+
switch (category_) {
98+
case Category::Continuous:
99+
std::cout << "Continuous ";
100+
break;
101+
case Category::Discrete:
102+
std::cout << "Discrete ";
103+
break;
104+
case Category::Hybrid:
105+
std::cout << "Hybrid ";
106+
break;
107+
case Category::None:
108+
std::cout << "None ";
109+
break;
110+
}
111+
88112
std::cout << "[";
89113
for (size_t c = 0; c < continuousKeys_.size(); c++) {
90114
std::cout << formatter(continuousKeys_.at(c));

gtsam/hybrid/HybridFactor.h

+10-9
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,13 @@ DiscreteKeys CollectDiscreteKeys(const DiscreteKeys &key1,
5252
* @ingroup hybrid
5353
*/
5454
class GTSAM_EXPORT HybridFactor : public Factor {
55+
public:
56+
/// Enum to help with categorizing hybrid factors.
57+
enum class Category { None, Discrete, Continuous, Hybrid };
58+
5559
private:
56-
bool isDiscrete_ = false;
57-
bool isContinuous_ = false;
58-
bool isHybrid_ = false;
60+
/// Record what category of HybridFactor this is.
61+
Category category_ = Category::None;
5962

6063
protected:
6164
// Set of DiscreteKeys for this factor.
@@ -116,13 +119,13 @@ class GTSAM_EXPORT HybridFactor : public Factor {
116119
/// @{
117120

118121
/// True if this is a factor of discrete variables only.
119-
bool isDiscrete() const { return isDiscrete_; }
122+
bool isDiscrete() const { return category_ == Category::Discrete; }
120123

121124
/// True if this is a factor of continuous variables only.
122-
bool isContinuous() const { return isContinuous_; }
125+
bool isContinuous() const { return category_ == Category::Continuous; }
123126

124127
/// True is this is a Discrete-Continuous factor.
125-
bool isHybrid() const { return isHybrid_; }
128+
bool isHybrid() const { return category_ == Category::Hybrid; }
126129

127130
/// Return the number of continuous variables in this factor.
128131
size_t nrContinuous() const { return continuousKeys_.size(); }
@@ -142,9 +145,7 @@ class GTSAM_EXPORT HybridFactor : public Factor {
142145
template <class ARCHIVE>
143146
void serialize(ARCHIVE &ar, const unsigned int /*version*/) {
144147
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
145-
ar &BOOST_SERIALIZATION_NVP(isDiscrete_);
146-
ar &BOOST_SERIALIZATION_NVP(isContinuous_);
147-
ar &BOOST_SERIALIZATION_NVP(isHybrid_);
148+
ar &BOOST_SERIALIZATION_NVP(category_);
148149
ar &BOOST_SERIALIZATION_NVP(discreteKeys_);
149150
ar &BOOST_SERIALIZATION_NVP(continuousKeys_);
150151
}

gtsam/hybrid/HybridGaussianConditional.cpp

+3-12
Original file line numberDiff line numberDiff line change
@@ -55,23 +55,14 @@ HybridGaussianConditional::conditionals() const {
5555
return conditionals_;
5656
}
5757

58-
/* *******************************************************************************/
59-
HybridGaussianConditional::HybridGaussianConditional(
60-
KeyVector &&continuousFrontals, KeyVector &&continuousParents,
61-
DiscreteKeys &&discreteParents,
62-
std::vector<GaussianConditional::shared_ptr> &&conditionals)
63-
: HybridGaussianConditional(continuousFrontals, continuousParents,
64-
discreteParents,
65-
Conditionals(discreteParents, conditionals)) {}
66-
6758
/* *******************************************************************************/
6859
HybridGaussianConditional::HybridGaussianConditional(
6960
const KeyVector &continuousFrontals, const KeyVector &continuousParents,
70-
const DiscreteKeys &discreteParents,
61+
const DiscreteKey &discreteParent,
7162
const std::vector<GaussianConditional::shared_ptr> &conditionals)
7263
: HybridGaussianConditional(continuousFrontals, continuousParents,
73-
discreteParents,
74-
Conditionals(discreteParents, conditionals)) {}
64+
DiscreteKeys{discreteParent},
65+
Conditionals({discreteParent}, conditionals)) {}
7566

7667
/* *******************************************************************************/
7768
// TODO(dellaert): This is copy/paste: HybridGaussianConditional should be

gtsam/hybrid/HybridGaussianConditional.h

+6-17
Original file line numberDiff line numberDiff line change
@@ -107,29 +107,18 @@ class GTSAM_EXPORT HybridGaussianConditional
107107
const Conditionals &conditionals);
108108

109109
/**
110-
* @brief Make a Gaussian Mixture from a list of Gaussian conditionals
110+
* @brief Make a Gaussian Mixture from a vector of Gaussian conditionals.
111+
* The DecisionTree-based constructor is preferred over this one.
111112
*
112113
* @param continuousFrontals The continuous frontal variables
113114
* @param continuousParents The continuous parent variables
114-
* @param discreteParents Discrete parents variables
115-
* @param conditionals List of conditionals
116-
*/
117-
HybridGaussianConditional(
118-
KeyVector &&continuousFrontals, KeyVector &&continuousParents,
119-
DiscreteKeys &&discreteParents,
120-
std::vector<GaussianConditional::shared_ptr> &&conditionals);
121-
122-
/**
123-
* @brief Make a Gaussian Mixture from a list of Gaussian conditionals
124-
*
125-
* @param continuousFrontals The continuous frontal variables
126-
* @param continuousParents The continuous parent variables
127-
* @param discreteParents Discrete parents variables
128-
* @param conditionals List of conditionals
115+
* @param discreteParent Single discrete parent variable
116+
* @param conditionals Vector of conditionals with the same size as the
117+
* cardinality of the discrete parent.
129118
*/
130119
HybridGaussianConditional(
131120
const KeyVector &continuousFrontals, const KeyVector &continuousParents,
132-
const DiscreteKeys &discreteParents,
121+
const DiscreteKey &discreteParent,
133122
const std::vector<GaussianConditional::shared_ptr> &conditionals);
134123

135124
/// @}

gtsam/hybrid/HybridGaussianFactor.cpp

+5-3
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,11 @@ HybridGaussianFactor::Factors augment(
4242
const HybridGaussianFactor::FactorValuePairs &factors) {
4343
// Find the minimum value so we can "proselytize" to positive values.
4444
// Done because we can't have sqrt of negative numbers.
45-
auto unzipped_pair = unzip(factors);
46-
const HybridGaussianFactor::Factors gaussianFactors = unzipped_pair.first;
47-
const AlgebraicDecisionTree<Key> valueTree = unzipped_pair.second;
45+
HybridGaussianFactor::Factors gaussianFactors;
46+
AlgebraicDecisionTree<Key> valueTree;
47+
std::tie(gaussianFactors, valueTree) = unzip(factors);
48+
49+
// Normalize
4850
double min_value = valueTree.min();
4951
AlgebraicDecisionTree<Key> values =
5052
valueTree.apply([&min_value](double n) { return n - min_value; });

gtsam/hybrid/HybridGaussianFactor.h

+5-5
Original file line numberDiff line numberDiff line change
@@ -96,15 +96,15 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
9696
* GaussianFactor shared pointers.
9797
*
9898
* @param continuousKeys Vector of keys for continuous factors.
99-
* @param discreteKeys Vector of discrete keys.
99+
* @param discreteKey The discrete key to index each component.
100100
* @param factors Vector of gaussian factor shared pointers
101-
* and arbitrary scalars.
101+
* and arbitrary scalars. Same size as the cardinality of discreteKey.
102102
*/
103103
HybridGaussianFactor(const KeyVector &continuousKeys,
104-
const DiscreteKeys &discreteKeys,
104+
const DiscreteKey &discreteKey,
105105
const std::vector<GaussianFactorValuePair> &factors)
106-
: HybridGaussianFactor(continuousKeys, discreteKeys,
107-
FactorValuePairs(discreteKeys, factors)) {}
106+
: HybridGaussianFactor(continuousKeys, {discreteKey},
107+
FactorValuePairs({discreteKey}, factors)) {}
108108

109109
/// @}
110110
/// @name Testable

gtsam/hybrid/HybridGaussianFactorGraph.cpp

+7-6
Original file line numberDiff line numberDiff line change
@@ -114,10 +114,11 @@ void HybridGaussianFactorGraph::printErrors(
114114
<< "\n";
115115
} else {
116116
// Is hybrid
117-
auto mixtureComponent =
117+
auto conditionalComponent =
118118
hc->asMixture()->operator()(values.discrete());
119-
mixtureComponent->print(ss.str(), keyFormatter);
120-
std::cout << "error = " << mixtureComponent->error(values) << "\n";
119+
conditionalComponent->print(ss.str(), keyFormatter);
120+
std::cout << "error = " << conditionalComponent->error(values)
121+
<< "\n";
121122
}
122123
}
123124
} else if (auto gf = std::dynamic_pointer_cast<GaussianFactor>(factor)) {
@@ -411,10 +412,10 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
411412
// Create the HybridGaussianConditional from the conditionals
412413
HybridGaussianConditional::Conditionals conditionals(
413414
eliminationResults, [](const Result &pair) { return pair.first; });
414-
auto gaussianMixture = std::make_shared<HybridGaussianConditional>(
415+
auto hybridGaussian = std::make_shared<HybridGaussianConditional>(
415416
frontalKeys, continuousSeparator, discreteSeparator, conditionals);
416417

417-
return {std::make_shared<HybridConditional>(gaussianMixture), newFactor};
418+
return {std::make_shared<HybridConditional>(hybridGaussian), newFactor};
418419
}
419420

420421
/* ************************************************************************
@@ -465,7 +466,7 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors,
465466
// Now we will need to know how to retrieve the corresponding continuous
466467
// densities for the assignment (c1,c2,c3) (OR (c2,c3,c1), note there is NO
467468
// defined order!). We also need to consider when there is pruning. Two
468-
// mixture factors could have different pruning patterns - one could have
469+
// hybrid factors could have different pruning patterns - one could have
469470
// (c1=0,c2=1) pruned, and another could have (c2=0,c3=1) pruned, and this
470471
// creates a big problem in how to identify the intersection of non-pruned
471472
// branches.

gtsam/hybrid/HybridNonlinearFactor.h

+7-6
Original file line numberDiff line numberDiff line change
@@ -92,17 +92,18 @@ class HybridNonlinearFactor : public HybridFactor {
9292
* @tparam FACTOR The type of the factor shared pointers being passed in.
9393
* Will be typecast to NonlinearFactor shared pointers.
9494
* @param keys Vector of keys for continuous factors.
95-
* @param discreteKeys Vector of discrete keys.
95+
* @param discreteKey The discrete key indexing each component factor.
9696
* @param factors Vector of nonlinear factor and scalar pairs.
97+
* Same size as the cardinality of discreteKey.
9798
* @param normalized Flag indicating if the factor error is already
9899
* normalized.
99100
*/
100101
template <typename FACTOR>
101102
HybridNonlinearFactor(
102-
const KeyVector& keys, const DiscreteKeys& discreteKeys,
103+
const KeyVector& keys, const DiscreteKey& discreteKey,
103104
const std::vector<std::pair<std::shared_ptr<FACTOR>, double>>& factors,
104105
bool normalized = false)
105-
: Base(keys, discreteKeys), normalized_(normalized) {
106+
: Base(keys, {discreteKey}), normalized_(normalized) {
106107
std::vector<NonlinearFactorValuePair> nonlinear_factors;
107108
KeySet continuous_keys_set(keys.begin(), keys.end());
108109
KeySet factor_keys_set;
@@ -118,7 +119,7 @@ class HybridNonlinearFactor : public HybridFactor {
118119
"Factors passed into HybridNonlinearFactor need to be nonlinear!");
119120
}
120121
}
121-
factors_ = Factors(discreteKeys, nonlinear_factors);
122+
factors_ = Factors({discreteKey}, nonlinear_factors);
122123

123124
if (continuous_keys_set != factor_keys_set) {
124125
throw std::runtime_error(
@@ -140,7 +141,7 @@ class HybridNonlinearFactor : public HybridFactor {
140141
auto errorFunc =
141142
[continuousValues](const std::pair<sharedFactor, double>& f) {
142143
auto [factor, val] = f;
143-
return factor->error(continuousValues) + (0.5 * val * val);
144+
return factor->error(continuousValues) + (0.5 * val);
144145
};
145146
DecisionTree<Key, double> result(factors_, errorFunc);
146147
return result;
@@ -159,7 +160,7 @@ class HybridNonlinearFactor : public HybridFactor {
159160
auto [factor, val] = factors_(discreteValues);
160161
// Compute the error for the selected factor
161162
const double factorError = factor->error(continuousValues);
162-
return factorError + (0.5 * val * val);
163+
return factorError + (0.5 * val);
163164
}
164165

165166
/**

gtsam/hybrid/hybrid.i

+8-4
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ virtual class HybridConditional {
7676
class HybridGaussianFactor : gtsam::HybridFactor {
7777
HybridGaussianFactor(
7878
const gtsam::KeyVector& continuousKeys,
79-
const gtsam::DiscreteKeys& discreteKeys,
79+
const gtsam::DiscreteKey& discreteKey,
8080
const std::vector<std::pair<gtsam::GaussianFactor::shared_ptr, double>>&
8181
factorsList);
8282

@@ -91,8 +91,12 @@ class HybridGaussianConditional : gtsam::HybridFactor {
9191
const gtsam::KeyVector& continuousFrontals,
9292
const gtsam::KeyVector& continuousParents,
9393
const gtsam::DiscreteKeys& discreteParents,
94-
const std::vector<gtsam::GaussianConditional::shared_ptr>&
95-
conditionalsList);
94+
const gtsam::HybridGaussianConditional::Conditionals& conditionals);
95+
HybridGaussianConditional(
96+
const gtsam::KeyVector& continuousFrontals,
97+
const gtsam::KeyVector& continuousParents,
98+
const gtsam::DiscreteKey& discreteParent,
99+
const std::vector<gtsam::GaussianConditional::shared_ptr>& conditionals);
96100

97101
gtsam::HybridGaussianFactor* likelihood(
98102
const gtsam::VectorValues& frontals) const;
@@ -248,7 +252,7 @@ class HybridNonlinearFactor : gtsam::HybridFactor {
248252
bool normalized = false);
249253

250254
HybridNonlinearFactor(
251-
const gtsam::KeyVector& keys, const gtsam::DiscreteKeys& discreteKeys,
255+
const gtsam::KeyVector& keys, const gtsam::DiscreteKey& discreteKey,
252256
const std::vector<std::pair<gtsam::NonlinearFactor*, double>>& factors,
253257
bool normalized = false);
254258

0 commit comments

Comments
 (0)