Skip to content

Commit d4923db

Browse files
committed
Use DecisionTree for constructing HybridGaussianConditional
1 parent 0913528 commit d4923db

9 files changed

+76
-83
lines changed

gtsam/hybrid/HybridGaussianConditional.cpp

-18
Original file line numberDiff line numberDiff line change
@@ -55,24 +55,6 @@ HybridGaussianConditional::conditionals() const {
5555
return conditionals_;
5656
}
5757

58-
/* *******************************************************************************/
59-
HybridGaussianConditional::HybridGaussianConditional(
60-
KeyVector &&continuousFrontals, KeyVector &&continuousParents,
61-
DiscreteKeys &&discreteParents,
62-
std::vector<GaussianConditional::shared_ptr> &&conditionals)
63-
: HybridGaussianConditional(continuousFrontals, continuousParents,
64-
discreteParents,
65-
Conditionals(discreteParents, conditionals)) {}
66-
67-
/* *******************************************************************************/
68-
HybridGaussianConditional::HybridGaussianConditional(
69-
const KeyVector &continuousFrontals, const KeyVector &continuousParents,
70-
const DiscreteKeys &discreteParents,
71-
const std::vector<GaussianConditional::shared_ptr> &conditionals)
72-
: HybridGaussianConditional(continuousFrontals, continuousParents,
73-
discreteParents,
74-
Conditionals(discreteParents, conditionals)) {}
75-
7658
/* *******************************************************************************/
7759
// TODO(dellaert): This is copy/paste: HybridGaussianConditional should be
7860
// derived from HybridGaussianFactor, no?

gtsam/hybrid/HybridGaussianConditional.h

+1-27
Original file line numberDiff line numberDiff line change
@@ -106,32 +106,6 @@ class GTSAM_EXPORT HybridGaussianConditional
106106
const DiscreteKeys &discreteParents,
107107
const Conditionals &conditionals);
108108

109-
/**
110-
* @brief Make a Gaussian Mixture from a list of Gaussian conditionals
111-
*
112-
* @param continuousFrontals The continuous frontal variables
113-
* @param continuousParents The continuous parent variables
114-
* @param discreteParents Discrete parents variables
115-
* @param conditionals List of conditionals
116-
*/
117-
HybridGaussianConditional(
118-
KeyVector &&continuousFrontals, KeyVector &&continuousParents,
119-
DiscreteKeys &&discreteParents,
120-
std::vector<GaussianConditional::shared_ptr> &&conditionals);
121-
122-
/**
123-
* @brief Make a Gaussian Mixture from a list of Gaussian conditionals
124-
*
125-
* @param continuousFrontals The continuous frontal variables
126-
* @param continuousParents The continuous parent variables
127-
* @param discreteParents Discrete parents variables
128-
* @param conditionals List of conditionals
129-
*/
130-
HybridGaussianConditional(
131-
const KeyVector &continuousFrontals, const KeyVector &continuousParents,
132-
const DiscreteKeys &discreteParents,
133-
const std::vector<GaussianConditional::shared_ptr> &conditionals);
134-
135109
/// @}
136110
/// @name Testable
137111
/// @{
@@ -273,7 +247,7 @@ class GTSAM_EXPORT HybridGaussianConditional
273247
#endif
274248
};
275249

276-
/// Return the DiscreteKey vector as a set.
250+
/// Return the DiscreteKeys vector as a set.
277251
std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &discreteKeys);
278252

279253
// traits

gtsam/hybrid/tests/TinyHybridExample.h

+5-4
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,13 @@ inline HybridBayesNet createHybridBayesNet(size_t num_measurements = 1,
4343
// Create Gaussian mixture z_i = x0 + noise for each measurement.
4444
for (size_t i = 0; i < num_measurements; i++) {
4545
const auto mode_i = manyModes ? DiscreteKey{M(i), 2} : mode;
46+
DiscreteKeys modes{mode_i};
47+
std::vector<GaussianConditional::shared_ptr> conditionals{
48+
GaussianConditional::sharedMeanAndStddev(Z(i), I_1x1, X(0), Z_1x1, 0.5),
49+
GaussianConditional::sharedMeanAndStddev(Z(i), I_1x1, X(0), Z_1x1, 3)};
4650
bayesNet.emplace_shared<HybridGaussianConditional>(
4751
KeyVector{Z(i)}, KeyVector{X(0)}, DiscreteKeys{mode_i},
48-
std::vector{GaussianConditional::sharedMeanAndStddev(Z(i), I_1x1, X(0),
49-
Z_1x1, 0.5),
50-
GaussianConditional::sharedMeanAndStddev(Z(i), I_1x1, X(0),
51-
Z_1x1, 3)});
52+
HybridGaussianConditional::Conditionals(modes, conditionals));
5253
}
5354

5455
// Create prior on X(0).

gtsam/hybrid/tests/testHybridBayesNet.cpp

+8-4
Original file line numberDiff line numberDiff line change
@@ -107,9 +107,11 @@ TEST(HybridBayesNet, evaluateHybrid) {
107107
// Create hybrid Bayes net.
108108
HybridBayesNet bayesNet;
109109
bayesNet.push_back(continuousConditional);
110+
DiscreteKeys discreteParents{Asia};
110111
bayesNet.emplace_shared<HybridGaussianConditional>(
111-
KeyVector{X(1)}, KeyVector{}, DiscreteKeys{Asia},
112-
std::vector{conditional0, conditional1});
112+
KeyVector{X(1)}, KeyVector{}, discreteParents,
113+
HybridGaussianConditional::Conditionals(
114+
discreteParents, std::vector{conditional0, conditional1}));
113115
bayesNet.emplace_shared<DiscreteConditional>(Asia, "99/1");
114116

115117
// Create values at which to evaluate.
@@ -168,9 +170,11 @@ TEST(HybridBayesNet, Error) {
168170
conditional1 = std::make_shared<GaussianConditional>(
169171
X(1), Vector1::Constant(2), I_1x1, model1);
170172

173+
DiscreteKeys discreteParents{Asia};
171174
auto gm = std::make_shared<HybridGaussianConditional>(
172-
KeyVector{X(1)}, KeyVector{}, DiscreteKeys{Asia},
173-
std::vector{conditional0, conditional1});
175+
KeyVector{X(1)}, KeyVector{}, discreteParents,
176+
HybridGaussianConditional::Conditionals(
177+
discreteParents, std::vector{conditional0, conditional1}));
174178
// Create hybrid Bayes net.
175179
HybridBayesNet bayesNet;
176180
bayesNet.push_back(continuousConditional);

gtsam/hybrid/tests/testHybridEstimation.cpp

+16-8
Original file line numberDiff line numberDiff line change
@@ -620,12 +620,16 @@ TEST(HybridEstimation, ModeSelection) {
620620
GaussianConditional::sharedMeanAndStddev(Z(0), -I_1x1, X(0), Z_1x1, 0.1));
621621
bn.push_back(
622622
GaussianConditional::sharedMeanAndStddev(Z(0), -I_1x1, X(1), Z_1x1, 0.1));
623+
624+
std::vector<GaussianConditional::shared_ptr> conditionals{
625+
GaussianConditional::sharedMeanAndStddev(Z(0), I_1x1, X(0), -I_1x1, X(1),
626+
Z_1x1, noise_loose),
627+
GaussianConditional::sharedMeanAndStddev(Z(0), I_1x1, X(0), -I_1x1, X(1),
628+
Z_1x1, noise_tight)};
623629
bn.emplace_shared<HybridGaussianConditional>(
624630
KeyVector{Z(0)}, KeyVector{X(0), X(1)}, DiscreteKeys{mode},
625-
std::vector{GaussianConditional::sharedMeanAndStddev(
626-
Z(0), I_1x1, X(0), -I_1x1, X(1), Z_1x1, noise_loose),
627-
GaussianConditional::sharedMeanAndStddev(
628-
Z(0), I_1x1, X(0), -I_1x1, X(1), Z_1x1, noise_tight)});
631+
HybridGaussianConditional::Conditionals(DiscreteKeys{mode},
632+
conditionals));
629633

630634
VectorValues vv;
631635
vv.insert(Z(0), Z_1x1);
@@ -651,12 +655,16 @@ TEST(HybridEstimation, ModeSelection2) {
651655
GaussianConditional::sharedMeanAndStddev(Z(0), -I_3x3, X(0), Z_3x1, 0.1));
652656
bn.push_back(
653657
GaussianConditional::sharedMeanAndStddev(Z(0), -I_3x3, X(1), Z_3x1, 0.1));
658+
659+
std::vector<GaussianConditional::shared_ptr> conditionals{
660+
GaussianConditional::sharedMeanAndStddev(Z(0), I_3x3, X(0), -I_3x3, X(1),
661+
Z_3x1, noise_loose),
662+
GaussianConditional::sharedMeanAndStddev(Z(0), I_3x3, X(0), -I_3x3, X(1),
663+
Z_3x1, noise_tight)};
654664
bn.emplace_shared<HybridGaussianConditional>(
655665
KeyVector{Z(0)}, KeyVector{X(0), X(1)}, DiscreteKeys{mode},
656-
std::vector{GaussianConditional::sharedMeanAndStddev(
657-
Z(0), I_3x3, X(0), -I_3x3, X(1), Z_3x1, noise_loose),
658-
GaussianConditional::sharedMeanAndStddev(
659-
Z(0), I_3x3, X(0), -I_3x3, X(1), Z_3x1, noise_tight)});
666+
HybridGaussianConditional::Conditionals(DiscreteKeys{mode},
667+
conditionals));
660668

661669
VectorValues vv;
662670
vv.insert(Z(0), Z_3x1);

gtsam/hybrid/tests/testHybridGaussianConditional.cpp

+6-2
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,9 @@ const std::vector<GaussianConditional::shared_ptr> conditionals{
5252
commonSigma),
5353
GaussianConditional::sharedMeanAndStddev(Z(0), I_1x1, X(0), Vector1(0.0),
5454
commonSigma)};
55-
const HybridGaussianConditional mixture({Z(0)}, {X(0)}, {mode}, conditionals);
55+
const HybridGaussianConditional mixture(
56+
{Z(0)}, {X(0)}, {mode},
57+
HybridGaussianConditional::Conditionals({mode}, conditionals));
5658
} // namespace equal_constants
5759

5860
/* ************************************************************************* */
@@ -153,7 +155,9 @@ const std::vector<GaussianConditional::shared_ptr> conditionals{
153155
0.5),
154156
GaussianConditional::sharedMeanAndStddev(Z(0), I_1x1, X(0), Vector1(0.0),
155157
3.0)};
156-
const HybridGaussianConditional mixture({Z(0)}, {X(0)}, {mode}, conditionals);
158+
const HybridGaussianConditional mixture(
159+
{Z(0)}, {X(0)}, {mode},
160+
HybridGaussianConditional::Conditionals({mode}, conditionals));
157161
} // namespace mode_dependent_constants
158162

159163
/* ************************************************************************* */

gtsam/hybrid/tests/testHybridGaussianFactor.cpp

+8-2
Original file line numberDiff line numberDiff line change
@@ -233,8 +233,11 @@ static HybridBayesNet GetGaussianMixtureModel(double mu0, double mu1,
233233
c1 = make_shared<GaussianConditional>(z, Vector1(mu1), I_1x1, model1);
234234

235235
HybridBayesNet hbn;
236+
DiscreteKeys discreteParents{m};
236237
hbn.emplace_shared<HybridGaussianConditional>(
237-
KeyVector{z}, KeyVector{}, DiscreteKeys{m}, std::vector{c0, c1});
238+
KeyVector{z}, KeyVector{}, discreteParents,
239+
HybridGaussianConditional::Conditionals(discreteParents,
240+
std::vector{c0, c1}));
238241

239242
auto mixing = make_shared<DiscreteConditional>(m, "50/50");
240243
hbn.push_back(mixing);
@@ -408,8 +411,11 @@ static HybridGaussianConditional::shared_ptr CreateHybridMotionModel(
408411
-I_1x1, model0),
409412
c1 = make_shared<GaussianConditional>(X(1), Vector1(mu1), I_1x1, X(0),
410413
-I_1x1, model1);
414+
DiscreteKeys discreteParents{m1};
411415
return std::make_shared<HybridGaussianConditional>(
412-
KeyVector{X(1)}, KeyVector{X(0)}, DiscreteKeys{m1}, std::vector{c0, c1});
416+
KeyVector{X(1)}, KeyVector{X(0)}, discreteParents,
417+
HybridGaussianConditional::Conditionals(discreteParents,
418+
std::vector{c0, c1}));
413419
}
414420

415421
/// Create two state Bayes network with 1 or two measurement models

gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp

+30-17
Original file line numberDiff line numberDiff line change
@@ -682,8 +682,11 @@ TEST(HybridGaussianFactorGraph, ErrorTreeWithConditional) {
682682
x0, -I_1x1, model0),
683683
c1 = make_shared<GaussianConditional>(f01, Vector1(mu), I_1x1, x1, I_1x1,
684684
x0, -I_1x1, model1);
685+
DiscreteKeys discreteParents{m1};
685686
hbn.emplace_shared<HybridGaussianConditional>(
686-
KeyVector{f01}, KeyVector{x0, x1}, DiscreteKeys{m1}, std::vector{c0, c1});
687+
KeyVector{f01}, KeyVector{x0, x1}, discreteParents,
688+
HybridGaussianConditional::Conditionals(discreteParents,
689+
std::vector{c0, c1}));
687690

688691
// Discrete uniform prior.
689692
hbn.emplace_shared<DiscreteConditional>(m1, "0.5/0.5");
@@ -806,9 +809,11 @@ TEST(HybridGaussianFactorGraph, EliminateTiny1) {
806809
X(0), Vector1(14.1421), I_1x1 * 2.82843),
807810
conditional1 = std::make_shared<GaussianConditional>(
808811
X(0), Vector1(10.1379), I_1x1 * 2.02759);
812+
DiscreteKeys discreteParents{mode};
809813
expectedBayesNet.emplace_shared<HybridGaussianConditional>(
810-
KeyVector{X(0)}, KeyVector{}, DiscreteKeys{mode},
811-
std::vector{conditional0, conditional1});
814+
KeyVector{X(0)}, KeyVector{}, discreteParents,
815+
HybridGaussianConditional::Conditionals(
816+
discreteParents, std::vector{conditional0, conditional1}));
812817

813818
// Add prior on mode.
814819
expectedBayesNet.emplace_shared<DiscreteConditional>(mode, "74/26");
@@ -831,12 +836,13 @@ TEST(HybridGaussianFactorGraph, EliminateTiny1Swapped) {
831836
HybridBayesNet bn;
832837

833838
// Create Gaussian mixture z_0 = x0 + noise for each measurement.
839+
std::vector<GaussianConditional::shared_ptr> conditionals{
840+
GaussianConditional::sharedMeanAndStddev(Z(0), I_1x1, X(0), Z_1x1, 3),
841+
GaussianConditional::sharedMeanAndStddev(Z(0), I_1x1, X(0), Z_1x1, 0.5)};
834842
auto gm = std::make_shared<HybridGaussianConditional>(
835843
KeyVector{Z(0)}, KeyVector{X(0)}, DiscreteKeys{mode},
836-
std::vector{
837-
GaussianConditional::sharedMeanAndStddev(Z(0), I_1x1, X(0), Z_1x1, 3),
838-
GaussianConditional::sharedMeanAndStddev(Z(0), I_1x1, X(0), Z_1x1,
839-
0.5)});
844+
HybridGaussianConditional::Conditionals(DiscreteKeys{mode},
845+
conditionals));
840846
bn.push_back(gm);
841847

842848
// Create prior on X(0).
@@ -865,7 +871,8 @@ TEST(HybridGaussianFactorGraph, EliminateTiny1Swapped) {
865871
X(0), Vector1(14.1421), I_1x1 * 2.82843);
866872
expectedBayesNet.emplace_shared<HybridGaussianConditional>(
867873
KeyVector{X(0)}, KeyVector{}, DiscreteKeys{mode},
868-
std::vector{conditional0, conditional1});
874+
HybridGaussianConditional::Conditionals(
875+
DiscreteKeys{mode}, std::vector{conditional0, conditional1}));
869876

870877
// Add prior on mode.
871878
expectedBayesNet.emplace_shared<DiscreteConditional>(mode, "1/1");
@@ -902,7 +909,8 @@ TEST(HybridGaussianFactorGraph, EliminateTiny2) {
902909
X(0), Vector1(10.274), I_1x1 * 2.0548);
903910
expectedBayesNet.emplace_shared<HybridGaussianConditional>(
904911
KeyVector{X(0)}, KeyVector{}, DiscreteKeys{mode},
905-
std::vector{conditional0, conditional1});
912+
HybridGaussianConditional::Conditionals(
913+
DiscreteKeys{mode}, std::vector{conditional0, conditional1}));
906914

907915
// Add prior on mode.
908916
expectedBayesNet.emplace_shared<DiscreteConditional>(mode, "23/77");
@@ -947,12 +955,14 @@ TEST(HybridGaussianFactorGraph, EliminateSwitchingNetwork) {
947955
for (size_t t : {0, 1, 2}) {
948956
// Create Gaussian mixture on Z(t) conditioned on X(t) and mode N(t):
949957
const auto noise_mode_t = DiscreteKey{N(t), 2};
958+
std::vector<GaussianConditional::shared_ptr> conditionals{
959+
GaussianConditional::sharedMeanAndStddev(Z(t), I_1x1, X(t), Z_1x1, 0.5),
960+
GaussianConditional::sharedMeanAndStddev(Z(t), I_1x1, X(t), Z_1x1,
961+
3.0)};
950962
bn.emplace_shared<HybridGaussianConditional>(
951963
KeyVector{Z(t)}, KeyVector{X(t)}, DiscreteKeys{noise_mode_t},
952-
std::vector{GaussianConditional::sharedMeanAndStddev(Z(t), I_1x1, X(t),
953-
Z_1x1, 0.5),
954-
GaussianConditional::sharedMeanAndStddev(Z(t), I_1x1, X(t),
955-
Z_1x1, 3.0)});
964+
HybridGaussianConditional::Conditionals(DiscreteKeys{noise_mode_t},
965+
conditionals));
956966

957967
// Create prior on discrete mode N(t):
958968
bn.emplace_shared<DiscreteConditional>(noise_mode_t, "20/80");
@@ -962,12 +972,15 @@ TEST(HybridGaussianFactorGraph, EliminateSwitchingNetwork) {
962972
for (size_t t : {2, 1}) {
963973
// Create Gaussian mixture on X(t) conditioned on X(t-1) and mode M(t-1):
964974
const auto motion_model_t = DiscreteKey{M(t), 2};
975+
std::vector<GaussianConditional::shared_ptr> conditionals{
976+
GaussianConditional::sharedMeanAndStddev(X(t), I_1x1, X(t - 1), Z_1x1,
977+
0.2),
978+
GaussianConditional::sharedMeanAndStddev(X(t), I_1x1, X(t - 1), I_1x1,
979+
0.2)};
965980
auto gm = std::make_shared<HybridGaussianConditional>(
966981
KeyVector{X(t)}, KeyVector{X(t - 1)}, DiscreteKeys{motion_model_t},
967-
std::vector{GaussianConditional::sharedMeanAndStddev(
968-
X(t), I_1x1, X(t - 1), Z_1x1, 0.2),
969-
GaussianConditional::sharedMeanAndStddev(
970-
X(t), I_1x1, X(t - 1), I_1x1, 0.2)});
982+
HybridGaussianConditional::Conditionals(DiscreteKeys{motion_model_t},
983+
conditionals));
971984
bn.push_back(gm);
972985

973986
// Create prior on motion model M(t):

gtsam/hybrid/tests/testSerializationHybrid.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,8 @@ TEST(HybridSerialization, HybridGaussianConditional) {
116116
const auto conditional1 = std::make_shared<GaussianConditional>(
117117
GaussianConditional::FromMeanAndStddev(Z(0), I, X(0), Vector1(0), 3));
118118
const HybridGaussianConditional gm({Z(0)}, {X(0)}, {mode},
119-
{conditional0, conditional1});
119+
HybridGaussianConditional::Conditionals(
120+
{mode}, {conditional0, conditional1}));
120121

121122
EXPECT(equalsObj<HybridGaussianConditional>(gm));
122123
EXPECT(equalsXML<HybridGaussianConditional>(gm));

0 commit comments

Comments
 (0)