Skip to content

Commit f62805f

Browse files
committed
add method to select underlying continuous Gaussian graph given discrete assignment
1 parent 8372d84 commit f62805f

3 files changed

+79
-1
lines changed

gtsam/hybrid/HybridGaussianFactorGraph.cpp

+23-1
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,8 @@ static std::shared_ptr<Factor> createDiscreteFactor(
296296

297297
// Logspace version of:
298298
// exp(-factor->error(kEmpty)) * conditional->normalizationConstant();
299-
// We take negative of the logNormalizationConstant `log(1/k)` to get `log(k)`.
299+
// We take negative of the logNormalizationConstant `log(1/k)`
300+
// to get `log(k)`.
300301
return -factor->error(kEmpty) + (-conditional->logNormalizationConstant());
301302
};
302303

@@ -326,6 +327,7 @@ static std::shared_ptr<Factor> createGaussianMixtureFactor(
326327
auto hf = std::dynamic_pointer_cast<HessianFactor>(factor);
327328
if (!hf) throw std::runtime_error("Expected HessianFactor!");
328329
// Add 2.0 term since the constant term will be premultiplied by 0.5
330+
// as per the Hessian definition
329331
hf->constantTerm() += 2.0 * conditional->logNormalizationConstant();
330332
}
331333
return factor;
@@ -563,4 +565,24 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::probPrime(
563565
return prob_tree;
564566
}
565567

568+
/* ************************************************************************ */
569+
GaussianFactorGraph HybridGaussianFactorGraph::operator()(
570+
const DiscreteValues &assignment) const {
571+
GaussianFactorGraph gfg;
572+
for (auto &&f : *this) {
573+
if (auto gf = std::dynamic_pointer_cast<GaussianFactor>(f)) {
574+
gfg.push_back(gf);
575+
} else if (auto gc = std::dynamic_pointer_cast<GaussianConditional>(f)) {
576+
gfg.push_back(gf);
577+
} else if (auto gmf = std::dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
578+
gfg.push_back((*gmf)(assignment));
579+
} else if (auto gm = dynamic_pointer_cast<GaussianMixture>(f)) {
580+
gfg.push_back((*gm)(assignment));
581+
} else {
582+
continue;
583+
}
584+
}
585+
return gfg;
586+
}
587+
566588
} // namespace gtsam

gtsam/hybrid/HybridGaussianFactorGraph.h

+4
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,10 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
210210
GaussianFactorGraphTree assembleGraphTree() const;
211211

212212
/// @}
213+
214+
/// Get the GaussianFactorGraph at a given discrete assignment.
215+
GaussianFactorGraph operator()(const DiscreteValues& assignment) const;
216+
213217
};
214218

215219
} // namespace gtsam

gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp

+52
Original file line numberDiff line numberDiff line change
@@ -490,6 +490,58 @@ TEST(HybridGaussianFactorGraph, SwitchingTwoVar) {
490490
}
491491
}
492492

493+
/* ****************************************************************************/
494+
// Select a particular continuous factor graph given a discrete assignment
495+
TEST(HybridGaussianFactorGraph, DiscreteSelection) {
496+
Switching s(3);
497+
498+
HybridGaussianFactorGraph graph = s.linearizedFactorGraph;
499+
500+
DiscreteValues dv00{{M(0), 0}, {M(1), 0}};
501+
GaussianFactorGraph continuous_00 = graph(dv00);
502+
GaussianFactorGraph expected_00;
503+
expected_00.push_back(JacobianFactor(X(0), I_1x1 * 10, Vector1(-10)));
504+
expected_00.push_back(JacobianFactor(X(0), -I_1x1, X(1), I_1x1, Vector1(-1)));
505+
expected_00.push_back(JacobianFactor(X(1), -I_1x1, X(2), I_1x1, Vector1(-1)));
506+
expected_00.push_back(JacobianFactor(X(1), I_1x1 * 10, Vector1(-10)));
507+
expected_00.push_back(JacobianFactor(X(2), I_1x1 * 10, Vector1(-10)));
508+
509+
EXPECT(assert_equal(expected_00, continuous_00));
510+
511+
DiscreteValues dv01{{M(0), 0}, {M(1), 1}};
512+
GaussianFactorGraph continuous_01 = graph(dv01);
513+
GaussianFactorGraph expected_01;
514+
expected_01.push_back(JacobianFactor(X(0), I_1x1 * 10, Vector1(-10)));
515+
expected_01.push_back(JacobianFactor(X(0), -I_1x1, X(1), I_1x1, Vector1(-1)));
516+
expected_01.push_back(JacobianFactor(X(1), -I_1x1, X(2), I_1x1, Vector1(-0)));
517+
expected_01.push_back(JacobianFactor(X(1), I_1x1 * 10, Vector1(-10)));
518+
expected_01.push_back(JacobianFactor(X(2), I_1x1 * 10, Vector1(-10)));
519+
520+
EXPECT(assert_equal(expected_01, continuous_01));
521+
522+
DiscreteValues dv10{{M(0), 1}, {M(1), 0}};
523+
GaussianFactorGraph continuous_10 = graph(dv10);
524+
GaussianFactorGraph expected_10;
525+
expected_10.push_back(JacobianFactor(X(0), I_1x1 * 10, Vector1(-10)));
526+
expected_10.push_back(JacobianFactor(X(0), -I_1x1, X(1), I_1x1, Vector1(-0)));
527+
expected_10.push_back(JacobianFactor(X(1), -I_1x1, X(2), I_1x1, Vector1(-1)));
528+
expected_10.push_back(JacobianFactor(X(1), I_1x1 * 10, Vector1(-10)));
529+
expected_10.push_back(JacobianFactor(X(2), I_1x1 * 10, Vector1(-10)));
530+
531+
EXPECT(assert_equal(expected_10, continuous_10));
532+
533+
DiscreteValues dv11{{M(0), 1}, {M(1), 1}};
534+
GaussianFactorGraph continuous_11 = graph(dv11);
535+
GaussianFactorGraph expected_11;
536+
expected_11.push_back(JacobianFactor(X(0), I_1x1 * 10, Vector1(-10)));
537+
expected_11.push_back(JacobianFactor(X(0), -I_1x1, X(1), I_1x1, Vector1(-0)));
538+
expected_11.push_back(JacobianFactor(X(1), -I_1x1, X(2), I_1x1, Vector1(-0)));
539+
expected_11.push_back(JacobianFactor(X(1), I_1x1 * 10, Vector1(-10)));
540+
expected_11.push_back(JacobianFactor(X(2), I_1x1 * 10, Vector1(-10)));
541+
542+
EXPECT(assert_equal(expected_11, continuous_11));
543+
}
544+
493545
/* ************************************************************************* */
494546
TEST(HybridGaussianFactorGraph, optimize) {
495547
HybridGaussianFactorGraph hfg;

0 commit comments

Comments
 (0)