Skip to content

Commit de68aec

Browse files
committed
fix tests
1 parent 336b494 commit de68aec

File tree

1 file changed

+25
-27
lines changed

1 file changed

+25
-27
lines changed

gtsam/hybrid/tests/testHybridGaussianFactor.cpp

+25-27
Original file line numberDiff line numberDiff line change
@@ -221,8 +221,9 @@ double prob_m_z(double mu0, double mu1, double sigma0, double sigma1,
221221
return 1 / (1 + e);
222222
};
223223

224-
static HybridBayesNet GetGaussianMixtureModel(double mu0, double mu1,
225-
double sigma0, double sigma1) {
224+
static HybridBayesNet GetHybridGaussianConditionalModel(double mu0, double mu1,
225+
double sigma0,
226+
double sigma1) {
226227
DiscreteKey m(M(0), 2);
227228
Key z = Z(0);
228229

@@ -254,7 +255,7 @@ static HybridBayesNet GetGaussianMixtureModel(double mu0, double mu1,
254255
* The resulting factor graph should eliminate to a Bayes net
255256
* which represents a sigmoid function.
256257
*/
257-
TEST(HybridGaussianFactor, GaussianMixtureModel) {
258+
TEST(HybridGaussianFactor, HybridGaussianConditionalModel) {
258259
using namespace test_gmm;
259260

260261
double mu0 = 1.0, mu1 = 3.0;
@@ -263,7 +264,7 @@ TEST(HybridGaussianFactor, GaussianMixtureModel) {
263264
DiscreteKey m(M(0), 2);
264265
Key z = Z(0);
265266

266-
auto hbn = GetGaussianMixtureModel(mu0, mu1, sigma, sigma);
267+
auto hbn = GetHybridGaussianConditionalModel(mu0, mu1, sigma, sigma);
267268

268269
// The result should be a sigmoid.
269270
// So should be P(m=1|z) = 0.5 at z=3.0 - 1.0=2.0
@@ -326,7 +327,7 @@ TEST(HybridGaussianFactor, GaussianMixtureModel) {
326327
* which represents a Gaussian-like function
327328
* where m1>m0 close to 3.1333.
328329
*/
329-
TEST(HybridGaussianFactor, GaussianMixtureModel2) {
330+
TEST(HybridGaussianFactor, HybridGaussianConditionalModel2) {
330331
using namespace test_gmm;
331332

332333
double mu0 = 1.0, mu1 = 3.0;
@@ -335,7 +336,7 @@ TEST(HybridGaussianFactor, GaussianMixtureModel2) {
335336
DiscreteKey m(M(0), 2);
336337
Key z = Z(0);
337338

338-
auto hbn = GetGaussianMixtureModel(mu0, mu1, sigma0, sigma1);
339+
auto hbn = GetHybridGaussianConditionalModel(mu0, mu1, sigma0, sigma1);
339340

340341
double m1_high = 3.133, lambda = 4;
341342
{
@@ -393,7 +394,7 @@ namespace test_two_state_estimation {
393394

394395
DiscreteKey m1(M(1), 2);
395396

396-
void addMeasurement(HybridBayesNet& hbn, Key z_key, Key x_key, double sigma) {
397+
void addMeasurement(HybridBayesNet &hbn, Key z_key, Key x_key, double sigma) {
397398
auto measurement_model = noiseModel::Isotropic::Sigma(1, sigma);
398399
hbn.emplace_shared<GaussianConditional>(z_key, Vector1(0.0), I_1x1, x_key,
399400
-I_1x1, measurement_model);
@@ -414,7 +415,7 @@ static HybridGaussianConditional::shared_ptr CreateHybridMotionModel(
414415

415416
/// Create two state Bayes network with 1 or two measurement models
416417
HybridBayesNet CreateBayesNet(
417-
const HybridGaussianConditional::shared_ptr& hybridMotionModel,
418+
const HybridGaussianConditional::shared_ptr &hybridMotionModel,
418419
bool add_second_measurement = false) {
419420
HybridBayesNet hbn;
420421

@@ -437,9 +438,9 @@ HybridBayesNet CreateBayesNet(
437438

438439
/// Approximate the discrete marginal P(m1) using importance sampling
439440
std::pair<double, double> approximateDiscreteMarginal(
440-
const HybridBayesNet& hbn,
441-
const HybridGaussianConditional::shared_ptr& hybridMotionModel,
442-
const VectorValues& given, size_t N = 100000) {
441+
const HybridBayesNet &hbn,
442+
const HybridGaussianConditional::shared_ptr &hybridMotionModel,
443+
const VectorValues &given, size_t N = 100000) {
443444
/// Create importance sampling network q(x0,x1,m) = p(x1|x0,m1) q(x0) P(m1),
444445
/// using q(x0) = N(z0, sigmaQ) to sample x0.
445446
HybridBayesNet q;
@@ -758,7 +759,7 @@ static HybridGaussianFactorGraph GetFactorGraphFromBayesNet(
758759
auto model1 = noiseModel::Isotropic::Sigma(1, sigmas[1]);
759760
auto prior_noise = noiseModel::Isotropic::Sigma(1, 1e-3);
760761

761-
// GaussianMixtureFactor component factors
762+
// HybridGaussianFactor component factors
762763
auto f0 =
763764
std::make_shared<BetweenFactor<double>>(X(0), X(1), means[0], model0);
764765
auto f1 =
@@ -783,8 +784,8 @@ static HybridGaussianFactorGraph GetFactorGraphFromBayesNet(
783784
std::make_shared<GaussianConditional>(terms0, 1, -d0, model0),
784785
std::make_shared<GaussianConditional>(terms1, 1, -d1, model1)};
785786
gtsam::HybridBayesNet bn;
786-
bn.emplace_shared<GaussianMixture>(KeyVector{Z(1)}, KeyVector{X(0), X(1)},
787-
DiscreteKeys{m1}, conditionals);
787+
bn.emplace_shared<HybridGaussianConditional>(
788+
KeyVector{Z(1)}, KeyVector{X(0), X(1)}, DiscreteKeys{m1}, conditionals);
788789

789790
// Create FG via toFactorGraph
790791
gtsam::VectorValues measurements;
@@ -812,7 +813,7 @@ static HybridGaussianFactorGraph GetFactorGraphFromBayesNet(
812813
* P(Z1 | X1, X2, M1) has 2 conditionals each for the binary
813814
* mode m1.
814815
*/
815-
TEST(GaussianMixtureFactor, FactorGraphFromBayesNet) {
816+
TEST(HybridGaussianFactor, FactorGraphFromBayesNet) {
816817
DiscreteKey m1(M(1), 2);
817818

818819
Values values;
@@ -889,8 +890,8 @@ namespace test_direct_factor_graph {
889890
* then perform linearization.
890891
*
891892
* @param values Initial values to linearize around.
892-
* @param means The means of the GaussianMixtureFactor components.
893-
* @param sigmas The covariances of the GaussianMixtureFactor components.
893+
* @param means The means of the HybridGaussianFactor components.
894+
* @param sigmas The covariances of the HybridGaussianFactor components.
894895
* @param m1 The discrete key.
895896
* @return HybridGaussianFactorGraph
896897
*/
@@ -908,13 +909,10 @@ static HybridGaussianFactorGraph CreateFactorGraph(
908909
std::make_shared<BetweenFactor<double>>(X(0), X(1), means[1], model1)
909910
->linearize(values);
910911

911-
// Create GaussianMixtureFactor
912-
std::vector<GaussianFactor::shared_ptr> factors{f0, f1};
913-
AlgebraicDecisionTree<Key> logNormalizers(
914-
{m1}, std::vector<double>{ComputeLogNormalizer(model0),
915-
ComputeLogNormalizer(model1)});
916-
GaussianMixtureFactor mixtureFactor({X(0), X(1)}, {m1}, factors,
917-
logNormalizers);
912+
// Create HybridGaussianFactor
913+
std::vector<GaussianFactorValuePair> factors{
914+
{f0, ComputeLogNormalizer(model0)}, {f1, ComputeLogNormalizer(model1)}};
915+
HybridGaussianFactor mixtureFactor({X(0), X(1)}, {m1}, factors);
918916

919917
HybridGaussianFactorGraph hfg;
920918
hfg.push_back(mixtureFactor);
@@ -934,7 +932,7 @@ static HybridGaussianFactorGraph CreateFactorGraph(
934932
* |
935933
* M1
936934
*/
937-
TEST(GaussianMixtureFactor, DifferentMeansFG) {
935+
TEST(HybridGaussianFactor, DifferentMeansFG) {
938936
using namespace test_direct_factor_graph;
939937

940938
DiscreteKey m1(M(1), 2);
@@ -1009,7 +1007,7 @@ TEST(GaussianMixtureFactor, DifferentMeansFG) {
10091007
* |
10101008
* M1
10111009
*/
1012-
TEST(GaussianMixtureFactor, DifferentCovariancesFG) {
1010+
TEST(HybridGaussianFactor, DifferentCovariancesFG) {
10131011
using namespace test_direct_factor_graph;
10141012

10151013
DiscreteKey m1(M(1), 2);
@@ -1021,7 +1019,7 @@ TEST(GaussianMixtureFactor, DifferentCovariancesFG) {
10211019

10221020
std::vector<double> means = {0.0, 0.0}, sigmas = {1e2, 1e-2};
10231021

1024-
// Create FG with GaussianMixtureFactor and prior on X1
1022+
// Create FG with HybridGaussianFactor and prior on X1
10251023
HybridGaussianFactorGraph mixture_fg =
10261024
CreateFactorGraph(values, means, sigmas, m1);
10271025

0 commit comments

Comments
 (0)