Skip to content

Commit 3a7a0b8

Browse files
committed
use enum to categorize HybridFactor
1 parent 1c74da2 commit 3a7a0b8

File tree

3 files changed

+42
-25
lines changed

3 files changed

+42
-25
lines changed

gtsam/hybrid/HybridFactor.cpp

+27-12
Original file line numberDiff line numberDiff line change
@@ -50,41 +50,56 @@ DiscreteKeys CollectDiscreteKeys(const DiscreteKeys &key1,
5050

5151
/* ************************************************************************ */
5252
HybridFactor::HybridFactor(const KeyVector &keys)
53-
: Base(keys), isContinuous_(true), continuousKeys_(keys) {}
53+
: Base(keys),
54+
category_(HybridCategory::Continuous),
55+
continuousKeys_(keys) {}
5456

5557
/* ************************************************************************ */
5658
HybridFactor::HybridFactor(const KeyVector &continuousKeys,
5759
const DiscreteKeys &discreteKeys)
5860
: Base(CollectKeys(continuousKeys, discreteKeys)),
59-
isDiscrete_((continuousKeys.size() == 0) && (discreteKeys.size() != 0)),
60-
isContinuous_((continuousKeys.size() != 0) && (discreteKeys.size() == 0)),
61-
isHybrid_((continuousKeys.size() != 0) && (discreteKeys.size() != 0)),
6261
discreteKeys_(discreteKeys),
63-
continuousKeys_(continuousKeys) {}
62+
continuousKeys_(continuousKeys) {
63+
if ((continuousKeys.size() == 0) && (discreteKeys.size() != 0)) {
64+
category_ = HybridCategory::Discrete;
65+
} else if ((continuousKeys.size() != 0) && (discreteKeys.size() == 0)) {
66+
category_ = HybridCategory::Continuous;
67+
} else {
68+
category_ = HybridCategory::Hybrid;
69+
}
70+
}
6471

6572
/* ************************************************************************ */
6673
HybridFactor::HybridFactor(const DiscreteKeys &discreteKeys)
6774
: Base(CollectKeys({}, discreteKeys)),
68-
isDiscrete_(true),
75+
category_(HybridCategory::Discrete),
6976
discreteKeys_(discreteKeys),
7077
continuousKeys_({}) {}
7178

7279
/* ************************************************************************ */
7380
bool HybridFactor::equals(const HybridFactor &lf, double tol) const {
7481
const This *e = dynamic_cast<const This *>(&lf);
75-
return e != nullptr && Base::equals(*e, tol) &&
76-
isDiscrete_ == e->isDiscrete_ && isContinuous_ == e->isContinuous_ &&
77-
isHybrid_ == e->isHybrid_ && continuousKeys_ == e->continuousKeys_ &&
82+
return e != nullptr && Base::equals(*e, tol) && category_ == e->category_ &&
83+
continuousKeys_ == e->continuousKeys_ &&
7884
discreteKeys_ == e->discreteKeys_;
7985
}
8086

8187
/* ************************************************************************ */
8288
void HybridFactor::print(const std::string &s,
8389
const KeyFormatter &formatter) const {
8490
std::cout << (s.empty() ? "" : s + "\n");
85-
if (isContinuous_) std::cout << "Continuous ";
86-
if (isDiscrete_) std::cout << "Discrete ";
87-
if (isHybrid_) std::cout << "Hybrid ";
91+
switch (category_) {
92+
case HybridCategory::Continuous:
93+
std::cout << "Continuous ";
94+
break;
95+
case HybridCategory::Discrete:
96+
std::cout << "Discrete ";
97+
break;
98+
case HybridCategory::Hybrid:
99+
std::cout << "Hybrid ";
100+
break;
101+
}
102+
88103
std::cout << "[";
89104
for (size_t c = 0; c < continuousKeys_.size(); c++) {
90105
std::cout << formatter(continuousKeys_.at(c));

gtsam/hybrid/HybridFactor.h

+9-9
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ KeyVector CollectKeys(const KeyVector &keys1, const KeyVector &keys2);
4141
DiscreteKeys CollectDiscreteKeys(const DiscreteKeys &key1,
4242
const DiscreteKeys &key2);
4343

44+
/// Enum to help with categorizing hybrid factors.
45+
enum class HybridCategory { Discrete, Continuous, Hybrid };
46+
4447
/**
4548
* Base class for *truly* hybrid probabilistic factors
4649
*
@@ -53,9 +56,8 @@ DiscreteKeys CollectDiscreteKeys(const DiscreteKeys &key1,
5356
*/
5457
class GTSAM_EXPORT HybridFactor : public Factor {
5558
private:
56-
bool isDiscrete_ = false;
57-
bool isContinuous_ = false;
58-
bool isHybrid_ = false;
59+
/// Record what category of HybridFactor this is.
60+
HybridCategory category_;
5961

6062
protected:
6163
// Set of DiscreteKeys for this factor.
@@ -116,13 +118,13 @@ class GTSAM_EXPORT HybridFactor : public Factor {
116118
/// @{
117119

118120
/// True if this is a factor of discrete variables only.
119-
bool isDiscrete() const { return isDiscrete_; }
121+
bool isDiscrete() const { return category_ == HybridCategory::Discrete; }
120122

121123
/// True if this is a factor of continuous variables only.
122-
bool isContinuous() const { return isContinuous_; }
124+
bool isContinuous() const { return category_ == HybridCategory::Continuous; }
123125

124126
/// True is this is a Discrete-Continuous factor.
125-
bool isHybrid() const { return isHybrid_; }
127+
bool isHybrid() const { return category_ == HybridCategory::Hybrid; }
126128

127129
/// Return the number of continuous variables in this factor.
128130
size_t nrContinuous() const { return continuousKeys_.size(); }
@@ -142,9 +144,7 @@ class GTSAM_EXPORT HybridFactor : public Factor {
142144
template <class ARCHIVE>
143145
void serialize(ARCHIVE &ar, const unsigned int /*version*/) {
144146
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
145-
ar &BOOST_SERIALIZATION_NVP(isDiscrete_);
146-
ar &BOOST_SERIALIZATION_NVP(isContinuous_);
147-
ar &BOOST_SERIALIZATION_NVP(isHybrid_);
147+
ar &BOOST_SERIALIZATION_NVP(category_);
148148
ar &BOOST_SERIALIZATION_NVP(discreteKeys_);
149149
ar &BOOST_SERIALIZATION_NVP(continuousKeys_);
150150
}

gtsam/hybrid/tests/testHybridBayesNet.cpp

+6-4
Original file line numberDiff line numberDiff line change
@@ -387,11 +387,13 @@ TEST(HybridBayesNet, Sampling) {
387387
std::make_shared<BetweenFactor<double>>(X(0), X(1), 0, noise_model);
388388
auto one_motion =
389389
std::make_shared<BetweenFactor<double>>(X(0), X(1), 1, noise_model);
390-
std::vector<NonlinearFactorValuePair> factors = {{zero_motion, 0.0},
391-
{one_motion, 0.0}};
390+
391+
DiscreteKeys discreteKeys{DiscreteKey(M(0), 2)};
392+
HybridNonlinearFactor::Factors factors(
393+
discreteKeys, {{zero_motion, 0.0}, {one_motion, 0.0}});
392394
nfg.emplace_shared<PriorFactor<double>>(X(0), 0.0, noise_model);
393-
nfg.emplace_shared<HybridNonlinearFactor>(
394-
KeyVector{X(0), X(1)}, DiscreteKeys{DiscreteKey(M(0), 2)}, factors);
395+
nfg.emplace_shared<HybridNonlinearFactor>(KeyVector{X(0), X(1)}, discreteKeys,
396+
factors);
395397

396398
DiscreteKey mode(M(0), 2);
397399
nfg.emplace_shared<DiscreteDistribution>(mode, "1/1");

0 commit comments

Comments
 (0)