Skip to content

Commit feab2a2

Browse files
authored
Merge pull request #1781 from borglab/discrete-improv
2 parents cc2e8de + 4d62b87 commit feab2a2

7 files changed

+63
-25
lines changed

gtsam/discrete/DiscreteBayesNet.cpp

+4-1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818

1919
#include <gtsam/discrete/DiscreteBayesNet.h>
2020
#include <gtsam/discrete/DiscreteConditional.h>
21+
#include <gtsam/discrete/DiscreteFactorGraph.h>
22+
#include <gtsam/discrete/DiscreteLookupDAG.h>
2123
#include <gtsam/inference/FactorGraph-inst.h>
2224

2325
namespace gtsam {
@@ -56,7 +58,8 @@ DiscreteValues DiscreteBayesNet::sample() const {
5658

5759
DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const {
5860
// sample each node in turn in topological sort order (parents first)
59-
for (auto it = std::make_reverse_iterator(end()); it != std::make_reverse_iterator(begin()); ++it) {
61+
for (auto it = std::make_reverse_iterator(end());
62+
it != std::make_reverse_iterator(begin()); ++it) {
6063
(*it)->sampleInPlace(&result);
6164
}
6265
return result;

gtsam/discrete/DiscreteConditional.cpp

+9-6
Original file line numberDiff line numberDiff line change
@@ -235,16 +235,19 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
235235
}
236236

237237
/* ************************************************************************** */
238-
size_t DiscreteConditional::argmax() const {
238+
size_t DiscreteConditional::argmax(const DiscreteValues& parentsValues) const {
239+
ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
240+
241+
// Initialize
239242
size_t maxValue = 0;
240243
double maxP = 0;
244+
DiscreteValues values = parentsValues;
245+
241246
assert(nrFrontals() == 1);
242-
assert(nrParents() == 0);
243-
DiscreteValues frontals;
244247
Key j = firstFrontalKey();
245248
for (size_t value = 0; value < cardinality(j); value++) {
246-
frontals[j] = value;
247-
double pValueS = (*this)(frontals);
249+
values[j] = value;
250+
double pValueS = (*this)(values);
248251
// Update MPE solution if better
249252
if (pValueS > maxP) {
250253
maxP = pValueS;
@@ -459,7 +462,7 @@ string DiscreteConditional::html(const KeyFormatter& keyFormatter,
459462
}
460463

461464
/* ************************************************************************* */
462-
double DiscreteConditional::evaluate(const HybridValues& x) const{
465+
double DiscreteConditional::evaluate(const HybridValues& x) const {
463466
return this->evaluate(x.discrete());
464467
}
465468
/* ************************************************************************* */

gtsam/discrete/DiscreteConditional.h

+7-9
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@
1818

1919
#pragma once
2020

21-
#include <gtsam/inference/Conditional-inst.h>
2221
#include <gtsam/discrete/DecisionTreeFactor.h>
2322
#include <gtsam/discrete/Signature.h>
23+
#include <gtsam/inference/Conditional-inst.h>
2424

2525
#include <memory>
2626
#include <string>
@@ -39,7 +39,7 @@ class GTSAM_EXPORT DiscreteConditional
3939
public Conditional<DecisionTreeFactor, DiscreteConditional> {
4040
public:
4141
// typedefs needed to play nice with gtsam
42-
typedef DiscreteConditional This; ///< Typedef to this class
42+
typedef DiscreteConditional This; ///< Typedef to this class
4343
typedef std::shared_ptr<This> shared_ptr; ///< shared_ptr to this class
4444
typedef DecisionTreeFactor BaseFactor; ///< Typedef to our factor base class
4545
typedef Conditional<BaseFactor, This>
@@ -159,9 +159,7 @@ class GTSAM_EXPORT DiscreteConditional
159159
/// @{
160160

161161
/// Log-probability is just -error(x).
162-
double logProbability(const DiscreteValues& x) const {
163-
return -error(x);
164-
}
162+
double logProbability(const DiscreteValues& x) const { return -error(x); }
165163

166164
/// print index signature only
167165
void printSignature(
@@ -214,10 +212,11 @@ class GTSAM_EXPORT DiscreteConditional
214212
size_t sample() const;
215213

216214
/**
217-
* @brief Return assignment that maximizes distribution.
218-
* @return Optimal assignment (1 frontal variable).
215+
* @brief Return assignment for single frontal variable that maximizes value.
216+
* @param parentsValues Known assignments for the parents.
217+
* @return maximizing assignment for the frontal variable.
219218
*/
220-
size_t argmax() const;
219+
size_t argmax(const DiscreteValues& parentsValues = DiscreteValues()) const;
221220

222221
/// @}
223222
/// @name Advanced Interface
@@ -244,7 +243,6 @@ class GTSAM_EXPORT DiscreteConditional
244243
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
245244
const Names& names = {}) const override;
246245

247-
248246
/// @}
249247
/// @name HybridValues methods.
250248
/// @{

gtsam/discrete/DiscreteLookupDAG.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,8 @@ DiscreteLookupDAG DiscreteLookupDAG::FromBayesNet(
119119

120120
DiscreteValues DiscreteLookupDAG::argmax(DiscreteValues result) const {
121121
// Argmax each node in turn in topological sort order (parents first).
122-
for (auto it = std::make_reverse_iterator(end()); it != std::make_reverse_iterator(begin()); ++it) {
122+
for (auto it = std::make_reverse_iterator(end());
123+
it != std::make_reverse_iterator(begin()); ++it) {
123124
// dereference to get the sharedFactor to the lookup table
124125
(*it)->argmaxInPlace(&result);
125126
}

gtsam/discrete/discrete.i

+7-1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@ class DiscreteKeys {
1414
bool empty() const;
1515
gtsam::DiscreteKey at(size_t n) const;
1616
void push_back(const gtsam::DiscreteKey& point_pair);
17+
void print(const std::string& s = "",
18+
const gtsam::KeyFormatter& keyFormatter =
19+
gtsam::DefaultKeyFormatter) const;
1720
};
1821

1922
// DiscreteValues is added in specializations/discrete.h as a std::map
@@ -104,6 +107,9 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
104107
DiscreteConditional(const gtsam::DecisionTreeFactor& joint,
105108
const gtsam::DecisionTreeFactor& marginal,
106109
const gtsam::Ordering& orderedKeys);
110+
DiscreteConditional(const gtsam::DiscreteKey& key,
111+
const gtsam::DiscreteKeys& parents,
112+
const std::vector<double>& table);
107113

108114
// Standard interface
109115
double logNormalizationConstant() const;
@@ -131,6 +137,7 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
131137
size_t sample(size_t value) const;
132138
size_t sample() const;
133139
void sampleInPlace(gtsam::DiscreteValues @parentsValues) const;
140+
size_t argmax(const gtsam::DiscreteValues& parents) const;
134141

135142
// Markdown and HTML
136143
string markdown(const gtsam::KeyFormatter& keyFormatter =
@@ -159,7 +166,6 @@ virtual class DiscreteDistribution : gtsam::DiscreteConditional {
159166
gtsam::DefaultKeyFormatter) const;
160167
double operator()(size_t value) const;
161168
std::vector<double> pmf() const;
162-
size_t argmax() const;
163169
};
164170

165171
#include <gtsam/discrete/DiscreteBayesNet.h>

gtsam/discrete/tests/testDiscreteBayesNet.cpp

+5-7
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,13 @@
1616
* @author Frank Dellaert
1717
*/
1818

19+
#include <CppUnitLite/TestHarness.h>
20+
#include <gtsam/base/Testable.h>
21+
#include <gtsam/base/Vector.h>
22+
#include <gtsam/base/debug.h>
1923
#include <gtsam/discrete/DiscreteBayesNet.h>
2024
#include <gtsam/discrete/DiscreteFactorGraph.h>
2125
#include <gtsam/discrete/DiscreteMarginals.h>
22-
#include <gtsam/base/debug.h>
23-
#include <gtsam/base/Testable.h>
24-
#include <gtsam/base/Vector.h>
25-
26-
#include <CppUnitLite/TestHarness.h>
2726

2827
#include <iostream>
2928
#include <string>
@@ -43,8 +42,7 @@ TEST(DiscreteBayesNet, bayesNet) {
4342
DiscreteKey Parent(0, 2), Child(1, 2);
4443

4544
auto prior = std::make_shared<DiscreteConditional>(Parent % "6/4");
46-
CHECK(assert_equal(ADT({Parent}, "0.6 0.4"),
47-
(ADT)*prior));
45+
CHECK(assert_equal(ADT({Parent}, "0.6 0.4"), (ADT)*prior));
4846
bayesNet.push_back(prior);
4947

5048
auto conditional =

gtsam/discrete/tests/testDiscreteConditional.cpp

+29
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,35 @@ TEST(DiscreteConditional, choose) {
289289
EXPECT(assert_equal(expected3, *actual3, 1e-9));
290290
}
291291

292+
/* ************************************************************************* */
293+
// Check argmax on P(C|D) and P(D), plus tie-breaking for P(B)
294+
TEST(DiscreteConditional, Argmax) {
295+
DiscreteKey B(2, 2), C(2, 2), D(4, 2);
296+
DiscreteConditional B_prior(D, "1/1");
297+
DiscreteConditional D_prior(D, "1/3");
298+
DiscreteConditional C_given_D((C | D) = "1/4 1/1");
299+
300+
// Case 1: Tie breaking
301+
size_t actual1 = B_prior.argmax();
302+
// In the case of ties, the first value is chosen.
303+
EXPECT_LONGS_EQUAL(0, actual1);
304+
// Case 2: No parents
305+
size_t actual2 = D_prior.argmax();
306+
// Selects 1 since it has 0.75 probability
307+
EXPECT_LONGS_EQUAL(1, actual2);
308+
309+
// Case 3: Given parent values
310+
DiscreteValues given;
311+
given[D.first] = 1;
312+
size_t actual3 = C_given_D.argmax(given);
313+
// Should be 0 since D=1 gives 0.5/0.5
314+
EXPECT_LONGS_EQUAL(0, actual3);
315+
316+
given[D.first] = 0;
317+
size_t actual4 = C_given_D.argmax(given);
318+
EXPECT_LONGS_EQUAL(1, actual4);
319+
}
320+
292321
/* ************************************************************************* */
293322
// Check markdown representation looks as expected, no parents.
294323
TEST(DiscreteConditional, markdown_prior) {

0 commit comments

Comments
 (0)