/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010-2022, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file HybridBayesNet.cpp
* @brief A Bayes net of Gaussian Conditionals indexed by discrete keys.
* @author Fan Jiang
* @author Varun Agrawal
* @author Shangjie Xue
* @author Frank Dellaert
* @date January 2022
*/
#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridValues.h>

#include <algorithm>  // std::set_difference, std::find
#include <random>     // std::mt19937_64
// In the wrappers we have no access to an RNG, so keep a default one ready.
static std::mt19937_64 kRandomNumberGenerator(42);
namespace gtsam {
/* ************************************************************************* */
void HybridBayesNet::print(const std::string &s,
const KeyFormatter &formatter) const {
Base::print(s, formatter);
}
/* ************************************************************************* */
bool HybridBayesNet::equals(const This &bn, double tol) const {
return Base::equals(bn, tol);
}
/* ************************************************************************* */
/**
* @brief Helper function to get the pruner functional.
*
* @param prunedDiscreteProbs The prob. decision tree of only discrete keys.
* @param conditional Conditional to prune. Used to get full assignment.
* @return std::function<double(const Assignment<Key> &, double)>
*/
std::function<double(const Assignment<Key> &, double)> prunerFunc(
const DecisionTreeFactor &prunedDiscreteProbs,
const HybridConditional &conditional) {
// Get the discrete keys as sets for the decision tree
// and the Gaussian mixture.
std::set<DiscreteKey> discreteProbsKeySet =
DiscreteKeysAsSet(prunedDiscreteProbs.discreteKeys());
std::set<DiscreteKey> conditionalKeySet =
DiscreteKeysAsSet(conditional.discreteKeys());
auto pruner = [prunedDiscreteProbs, discreteProbsKeySet, conditionalKeySet](
const Assignment<Key> &choices,
double probability) -> double {
// This corresponds to 0 probability
double pruned_prob = 0.0;
// Typecast so we can use this to look up the probability value.
DiscreteValues values(choices);
// Case where the Gaussian mixture has the same
// discrete keys as the decision tree.
if (conditionalKeySet == discreteProbsKeySet) {
if (prunedDiscreteProbs(values) == 0) {
return pruned_prob;
} else {
return probability;
}
} else {
// Due to branch merging (aka pruning) in DecisionTree, it is possible we
// get a `values` which doesn't have the full set of keys.
std::set<Key> valuesKeys;
for (auto kvp : values) {
valuesKeys.insert(kvp.first);
}
std::set<Key> conditionalKeys;
for (auto kvp : conditionalKeySet) {
conditionalKeys.insert(kvp.first);
}
// If true, then values is missing some keys
if (conditionalKeys != valuesKeys) {
// Get the keys present in conditionalKeys but not in valuesKeys
std::vector<Key> missing_keys;
std::set_difference(conditionalKeys.begin(), conditionalKeys.end(),
valuesKeys.begin(), valuesKeys.end(),
std::back_inserter(missing_keys));
// Insert missing keys with a default assignment.
for (auto missing_key : missing_keys) {
values[missing_key] = 0;
}
}
// Now we generate the full assignment by enumerating
// over all keys in the prunedDiscreteProbs.
// First we find the differing keys
std::vector<DiscreteKey> set_diff;
std::set_difference(discreteProbsKeySet.begin(),
discreteProbsKeySet.end(), conditionalKeySet.begin(),
conditionalKeySet.end(),
std::back_inserter(set_diff));
// Now enumerate over all assignments of the differing keys
const std::vector<DiscreteValues> assignments =
DiscreteValues::CartesianProduct(set_diff);
for (const DiscreteValues &assignment : assignments) {
DiscreteValues augmented_values(values);
augmented_values.insert(assignment);
// If any one of the sub-branches is non-zero,
// we need this probability.
if (prunedDiscreteProbs(augmented_values) > 0.0) {
return probability;
}
}
// If we are here, it means that all the sub-branches are 0,
// so we prune.
return pruned_prob;
}
};
return pruner;
}
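/* Usage sketch (illustrative only, not code from this file): the returned
 * functional can be applied leaf-wise to a probability tree via the
 * AlgebraicDecisionTree::apply overload used elsewhere in this file, zeroing
 * out leaves whose discrete assignment has zero probability in
 * prunedDiscreteProbs. Here `probs` is a hypothetical probability tree:
 *
 *   auto pruner = prunerFunc(prunedDiscreteProbs, *conditional);
 *   AlgebraicDecisionTree<Key> prunedProbs = probs.apply(pruner);
 */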
/* ************************************************************************* */
DecisionTreeFactor HybridBayesNet::pruneDiscreteConditionals(
size_t maxNrLeaves) {
// Compute the joint distribution over only the discrete keys.
DiscreteConditional discreteProbs;
std::vector<size_t> discrete_factor_idxs;
// Record frontal keys so we can maintain ordering
Ordering discrete_frontals;
for (size_t i = 0; i < this->size(); i++) {
auto conditional = this->at(i);
if (conditional->isDiscrete()) {
discreteProbs = discreteProbs * (*conditional->asDiscrete());
Ordering conditional_keys(conditional->frontals());
discrete_frontals += conditional_keys;
discrete_factor_idxs.push_back(i);
}
}
const DecisionTreeFactor prunedDiscreteProbs =
discreteProbs.prune(maxNrLeaves);
// Eliminate joint probability back into conditionals
DiscreteFactorGraph dfg{prunedDiscreteProbs};
DiscreteBayesNet::shared_ptr dbn = dfg.eliminateSequential(discrete_frontals);
// Assign pruned discrete conditionals back at the correct indices.
for (size_t i = 0; i < discrete_factor_idxs.size(); i++) {
size_t idx = discrete_factor_idxs.at(i);
this->at(idx) = std::make_shared<HybridConditional>(dbn->at(i));
}
return prunedDiscreteProbs;
}
/* ************************************************************************* */
HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
DecisionTreeFactor prunedDiscreteProbs =
this->pruneDiscreteConditionals(maxNrLeaves);
/* To prune, we visit (via visitWith) every leaf in each
 * HybridGaussianConditional. For each leaf, we use its assignment to look up
 * the discrete decision tree; if the probability there is 0.0, we set the
 * leaf to nullptr.
 *
 * We can later check the HybridGaussianConditional for nullptrs.
 */
HybridBayesNet prunedBayesNetFragment;
// Go through all the conditionals in the
// Bayes Net and prune them as per prunedDiscreteProbs.
for (auto &&conditional : *this) {
if (auto gm = conditional->asMixture()) {
// Make a copy of the Gaussian mixture and prune it!
auto prunedHybridGaussianConditional =
std::make_shared<HybridGaussianConditional>(*gm);
prunedHybridGaussianConditional->prune(
prunedDiscreteProbs); // imperative :-(
// Type-erase and add to the pruned Bayes Net fragment.
prunedBayesNetFragment.push_back(prunedHybridGaussianConditional);
} else {
// Add the non-HybridGaussianConditional conditional
prunedBayesNetFragment.push_back(conditional);
}
}
return prunedBayesNetFragment;
}
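/* Usage sketch (hypothetical `hbn`, using only the API defined in this file):
 *
 *   // Keep at most 4 leaves in the discrete part; hybrid Gaussian
 *   // conditionals are pruned consistently with the surviving assignments.
 *   HybridBayesNet pruned = hbn.prune(4);
 */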
/* ************************************************************************* */
GaussianBayesNet HybridBayesNet::choose(
const DiscreteValues &assignment) const {
GaussianBayesNet gbn;
for (auto &&conditional : *this) {
if (auto gm = conditional->asMixture()) {
// If conditional is hybrid, select based on assignment.
gbn.push_back((*gm)(assignment));
} else if (auto gc = conditional->asGaussian()) {
// If continuous only, add Gaussian conditional.
gbn.push_back(gc);
} else if (auto dc = conditional->asDiscrete()) {
// If conditional is discrete-only, we simply continue.
continue;
}
}
return gbn;
}
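/* Usage sketch (hypothetical `hbn`; M(0) is an assumed discrete mode key,
 * e.g. created with symbol_shorthand):
 *
 *   DiscreteValues assignment;
 *   assignment[M(0)] = 1;                            // pick mode 1
 *   GaussianBayesNet gbn = hbn.choose(assignment);   // purely Gaussian BN
 */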
/* ************************************************************************* */
HybridValues HybridBayesNet::optimize() const {
// Collect all the discrete factors to compute MPE
DiscreteFactorGraph discrete_fg;
for (auto &&conditional : *this) {
if (conditional->isDiscrete()) {
discrete_fg.push_back(conditional->asDiscrete());
}
}
// Solve for the MPE
DiscreteValues mpe = discrete_fg.optimize();
// Given the MPE, compute the optimal continuous values.
return HybridValues(optimize(mpe), mpe);
}
/* ************************************************************************* */
VectorValues HybridBayesNet::optimize(const DiscreteValues &assignment) const {
GaussianBayesNet gbn = choose(assignment);
// Check if there exists a nullptr in the GaussianBayesNet
// If yes, return an empty VectorValues
if (std::find(gbn.begin(), gbn.end(), nullptr) != gbn.end()) {
return VectorValues();
}
return gbn.optimize();
}
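/* Usage sketch: solve for the most probable explanation over the discrete
 * modes plus the optimal continuous values, or condition on a fixed mode
 * assignment (hypothetical `hbn` and `assignment`):
 *
 *   HybridValues mpe = hbn.optimize();           // MPE discrete + continuous
 *   VectorValues vv = hbn.optimize(assignment);  // continuous only, given modes
 */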
/* ************************************************************************* */
HybridValues HybridBayesNet::sample(const HybridValues &given,
std::mt19937_64 *rng) const {
DiscreteBayesNet dbn;
for (auto &&conditional : *this) {
if (conditional->isDiscrete()) {
// If conditional is discrete-only, we add to the discrete Bayes net.
dbn.push_back(conditional->asDiscrete());
}
}
// Sample a discrete assignment.
const DiscreteValues assignment = dbn.sample(given.discrete());
// Select the continuous Bayes net corresponding to the assignment.
GaussianBayesNet gbn = choose(assignment);
// Sample from the Gaussian Bayes net.
VectorValues sample = gbn.sample(given.continuous(), rng);
return {sample, assignment};
}
/* ************************************************************************* */
HybridValues HybridBayesNet::sample(std::mt19937_64 *rng) const {
HybridValues given;
return sample(given, rng);
}
/* ************************************************************************* */
HybridValues HybridBayesNet::sample(const HybridValues &given) const {
return sample(given, &kRandomNumberGenerator);
}
/* ************************************************************************* */
HybridValues HybridBayesNet::sample() const {
return sample(&kRandomNumberGenerator);
}
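/* Usage sketch: ancestral sampling, optionally conditioned on given values
 * and/or using a caller-supplied RNG (hypothetical `hbn` and `given`):
 *
 *   std::mt19937_64 rng(7);
 *   HybridValues s1 = hbn.sample(&rng);         // unconditional sample
 *   HybridValues s2 = hbn.sample(given, &rng);  // conditioned on given values
 */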
/* ************************************************************************* */
AlgebraicDecisionTree<Key> HybridBayesNet::errorTree(
const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> result(0.0);
// Iterate over each conditional.
for (auto &&conditional : *this) {
if (auto gm = conditional->asMixture()) {
// If conditional is hybrid, compute error for all assignments.
result = result + gm->errorTree(continuousValues);
} else if (auto gc = conditional->asGaussian()) {
// If continuous, get the error and add it to the result
double error = gc->error(continuousValues);
// Add the computed error to every leaf of the result tree.
result = result.apply(
[error](double leaf_value) { return leaf_value + error; });
} else if (auto dc = conditional->asDiscrete()) {
// If discrete, add the discrete error in the right branch
result = result.apply(
[dc](const Assignment<Key> &assignment, double leaf_value) {
return leaf_value + dc->error(DiscreteValues(assignment));
});
}
}
return result;
}
/* ************************************************************************* */
AlgebraicDecisionTree<Key> HybridBayesNet::logProbability(
const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> result(0.0);
// Iterate over each conditional.
for (auto &&conditional : *this) {
if (auto gm = conditional->asMixture()) {
// If conditional is hybrid, select based on assignment and compute
// logProbability.
result = result + gm->logProbability(continuousValues);
} else if (auto gc = conditional->asGaussian()) {
// If continuous, get the (double) logProbability and add it to the
// result
double logProbability = gc->logProbability(continuousValues);
// Add the computed logProbability to every leaf of the logProbability
// tree.
result = result.apply([logProbability](double leaf_value) {
return leaf_value + logProbability;
});
} else if (auto dc = conditional->asDiscrete()) {
// If discrete, add the discrete logProbability in the right branch
result = result.apply(
[dc](const Assignment<Key> &assignment, double leaf_value) {
return leaf_value + dc->logProbability(DiscreteValues(assignment));
});
}
}
return result;
}
/* ************************************************************************* */
AlgebraicDecisionTree<Key> HybridBayesNet::evaluate(
const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> tree = this->logProbability(continuousValues);
return tree.apply([](double log) { return exp(log); });
}
/* ************************************************************************* */
double HybridBayesNet::evaluate(const HybridValues &values) const {
return exp(logProbability(values));
}
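/* Usage sketch: evaluate the hybrid density either at a full HybridValues
 * (a single number) or at fixed continuous values (a decision tree with one
 * value per discrete assignment); `hbn`, `hv`, and `vv` are hypothetical:
 *
 *   double p = hbn.evaluate(hv);                         // exp(logProbability(hv))
 *   AlgebraicDecisionTree<Key> tree = hbn.evaluate(vv);  // per-assignment density
 */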
/* ************************************************************************* */
HybridGaussianFactorGraph HybridBayesNet::toFactorGraph(
const VectorValues &measurements) const {
HybridGaussianFactorGraph fg;
// For every node in the Bayes net whose frontal variables are in measurements,
// replace it by a likelihood factor:
for (auto &&conditional : *this) {
if (conditional->frontalsIn(measurements)) {
if (auto gc = conditional->asGaussian()) {
fg.push_back(gc->likelihood(measurements));
} else if (auto gm = conditional->asMixture()) {
fg.push_back(gm->likelihood(measurements));
} else {
throw std::runtime_error("Unknown conditional type");
}
} else {
fg.push_back(conditional);
}
}
return fg;
}
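/* Usage sketch: turn a generative hybrid Bayes net into a factor graph by
 * conditioning on measured frontal variables (hypothetical `hbn`; Z(0) is an
 * assumed measurement key, e.g. created with symbol_shorthand):
 *
 *   VectorValues measurements;
 *   measurements.insert(Z(0), Vector1(0.7));
 *   HybridGaussianFactorGraph fg = hbn.toFactorGraph(measurements);
 */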
} // namespace gtsam