Skip to content

Commit 231d1ad

Browse files
authored
Merge pull request #1696 from borglab/model-selection-integration
AlgebraicDecisionTree Helpers
2 parents 67976ad + 73d971a commit 231d1ad

7 files changed

+105
-16
lines changed

gtsam/discrete/AlgebraicDecisionTree.h

+36
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,42 @@ namespace gtsam {
196196
return this->apply(g, &Ring::div);
197197
}
198198

199+
/// Compute sum of all values
200+
double sum() const {
201+
double sum = 0;
202+
auto visitor = [&](double y) { sum += y; };
203+
this->visit(visitor);
204+
return sum;
205+
}
206+
207+
/**
208+
* @brief Helper method to perform normalization such that all leaves in the
209+
* tree sum to 1
210+
*
211+
* @param sum
212+
* @return AlgebraicDecisionTree
213+
*/
214+
AlgebraicDecisionTree normalize(double sum) const {
215+
return this->apply([&sum](const double& x) { return x / sum; });
216+
}
217+
218+
/// Find the minimum values amongst all leaves
219+
double min() const {
220+
double min = std::numeric_limits<double>::max();
221+
auto visitor = [&](double x) { min = x < min ? x : min; };
222+
this->visit(visitor);
223+
return min;
224+
}
225+
226+
/// Find the maximum values amongst all leaves
227+
double max() const {
228+
// Get the most negative value
229+
double max = -std::numeric_limits<double>::max();
230+
auto visitor = [&](double x) { max = x > max ? x : max; };
231+
this->visit(visitor);
232+
return max;
233+
}
234+
199235
/** sum out variable */
200236
AlgebraicDecisionTree sum(const L& label, size_t cardinality) const {
201237
return this->combine(label, cardinality, &Ring::add);

gtsam/discrete/tests/testAlgebraicDecisionTree.cpp

+49
Original file line numberDiff line numberDiff line change
@@ -593,6 +593,55 @@ TEST(ADT, zero) {
593593
EXPECT_DOUBLES_EQUAL(0, anotb(x11), 1e-9);
594594
}
595595

596+
/// Example ADT from 0 to 11.
597+
ADT exampleADT() {
598+
DiscreteKey A(0, 2), B(1, 3), C(2, 2);
599+
ADT f(A & B & C, "0 6 2 8 4 10 1 7 3 9 5 11");
600+
return f;
601+
}
602+
/* ************************************************************************** */
603+
// Test sum
604+
TEST(ADT, Sum) {
605+
ADT a = exampleADT();
606+
double expected_sum = 0;
607+
for (double i = 0; i < 12; i++) {
608+
expected_sum += i;
609+
}
610+
EXPECT_DOUBLES_EQUAL(expected_sum, a.sum(), 1e-9);
611+
}
612+
613+
/* ************************************************************************** */
614+
// Test normalize
615+
TEST(ADT, Normalize) {
616+
ADT a = exampleADT();
617+
double sum = a.sum();
618+
auto actual = a.normalize(sum);
619+
620+
DiscreteKey A(0, 2), B(1, 3), C(2, 2);
621+
DiscreteKeys keys = DiscreteKeys{A, B, C};
622+
std::vector<double> cpt{0 / sum, 6 / sum, 2 / sum, 8 / sum,
623+
4 / sum, 10 / sum, 1 / sum, 7 / sum,
624+
3 / sum, 9 / sum, 5 / sum, 11 / sum};
625+
ADT expected(keys, cpt);
626+
EXPECT(assert_equal(expected, actual));
627+
}
628+
629+
/* ************************************************************************** */
630+
// Test min
631+
TEST(ADT, Min) {
632+
ADT a = exampleADT();
633+
double min = a.min();
634+
EXPECT_DOUBLES_EQUAL(0.0, min, 1e-9);
635+
}
636+
637+
/* ************************************************************************** */
638+
// Test max
639+
TEST(ADT, Max) {
640+
ADT a = exampleADT();
641+
double max = a.max();
642+
EXPECT_DOUBLES_EQUAL(11.0, max, 1e-9);
643+
}
644+
596645
/* ************************************************************************* */
597646
int main() {
598647
TestResult tr;

gtsam/hybrid/GaussianMixture.h

+4-2
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,8 @@ class GTSAM_EXPORT GaussianMixture
6767
double logConstant_; ///< log of the normalization constant.
6868

6969
/**
70-
* @brief Convert a DecisionTree of factors into a DT of Gaussian FGs.
70+
* @brief Convert a DecisionTree of factors into
71+
* a DecisionTree of Gaussian factor graphs.
7172
*/
7273
GaussianFactorGraphTree asGaussianFactorGraphTree() const;
7374

@@ -214,7 +215,8 @@ class GTSAM_EXPORT GaussianMixture
214215
* @return AlgebraicDecisionTree<Key> A decision tree on the discrete keys
215216
* only, with the leaf values as the error for each assignment.
216217
*/
217-
AlgebraicDecisionTree<Key> errorTree(const VectorValues &continuousValues) const;
218+
AlgebraicDecisionTree<Key> errorTree(
219+
const VectorValues &continuousValues) const;
218220

219221
/**
220222
* @brief Compute the logProbability of this Gaussian Mixture.

gtsam/hybrid/GaussianMixtureFactor.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,8 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
135135
* @return AlgebraicDecisionTree<Key> A decision tree with the same keys
136136
* as the factors involved, and leaf values as the error.
137137
*/
138-
AlgebraicDecisionTree<Key> errorTree(const VectorValues &continuousValues) const;
138+
AlgebraicDecisionTree<Key> errorTree(
139+
const VectorValues &continuousValues) const;
139140

140141
/**
141142
* @brief Compute the log-likelihood, including the log-normalizing constant.

gtsam/hybrid/HybridValues.h

+6-9
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
/**
1313
* @file HybridValues.h
1414
* @date Jul 28, 2022
15+
* @author Varun Agrawal
1516
* @author Shangjie Xue
1617
*/
1718

@@ -54,13 +55,13 @@ class GTSAM_EXPORT HybridValues {
5455
HybridValues() = default;
5556

5657
/// Construct from DiscreteValues and VectorValues.
57-
HybridValues(const VectorValues &cv, const DiscreteValues &dv)
58-
: continuous_(cv), discrete_(dv){}
58+
HybridValues(const VectorValues& cv, const DiscreteValues& dv)
59+
: continuous_(cv), discrete_(dv) {}
5960

6061
/// Construct from all values types.
6162
HybridValues(const VectorValues& cv, const DiscreteValues& dv,
6263
const Values& v)
63-
: continuous_(cv), discrete_(dv), nonlinear_(v){}
64+
: continuous_(cv), discrete_(dv), nonlinear_(v) {}
6465

6566
/// @}
6667
/// @name Testable
@@ -101,9 +102,7 @@ class GTSAM_EXPORT HybridValues {
101102
bool existsDiscrete(Key j) { return (discrete_.find(j) != discrete_.end()); }
102103

103104
/// Check whether a variable with key \c j exists in values.
104-
bool existsNonlinear(Key j) {
105-
return nonlinear_.exists(j);
106-
}
105+
bool existsNonlinear(Key j) { return nonlinear_.exists(j); }
107106

108107
/// Check whether a variable with key \c j exists.
109108
bool exists(Key j) {
@@ -128,9 +127,7 @@ class GTSAM_EXPORT HybridValues {
128127
}
129128

130129
/// insert_or_assign() , similar to Values.h
131-
void insert_or_assign(Key j, size_t value) {
132-
discrete_[j] = value;
133-
}
130+
void insert_or_assign(Key j, size_t value) { discrete_[j] = value; }
134131

135132
/** Insert all continuous values from \c values. Throws an invalid_argument
136133
* exception if any keys to be inserted are already used. */

gtsam/hybrid/tests/testHybridEstimation.cpp

-1
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,6 @@ TEST(HybridEstimation, Probability) {
333333
for (auto discrete_conditional : *discreteBayesNet) {
334334
bayesNet->add(discrete_conditional);
335335
}
336-
auto discreteConditional = discreteBayesNet->at(0)->asDiscrete();
337336

338337
HybridValues hybrid_values = bayesNet->optimize();
339338

gtsam/nonlinear/nonlinear.i

+8-3
Original file line numberDiff line numberDiff line change
@@ -381,17 +381,22 @@ typedef gtsam::GncOptimizer<gtsam::GncParams<gtsam::LevenbergMarquardtParams>> G
381381
#include <gtsam/nonlinear/LevenbergMarquardtOptimizer.h>
382382
virtual class LevenbergMarquardtOptimizer : gtsam::NonlinearOptimizer {
383383
LevenbergMarquardtOptimizer(const gtsam::NonlinearFactorGraph& graph,
384-
const gtsam::Values& initialValues);
384+
const gtsam::Values& initialValues,
385+
const gtsam::LevenbergMarquardtParams& params =
386+
gtsam::LevenbergMarquardtParams());
385387
LevenbergMarquardtOptimizer(const gtsam::NonlinearFactorGraph& graph,
386388
const gtsam::Values& initialValues,
387-
const gtsam::LevenbergMarquardtParams& params);
389+
const gtsam::Ordering& ordering,
390+
const gtsam::LevenbergMarquardtParams& params =
391+
gtsam::LevenbergMarquardtParams());
392+
388393
double lambda() const;
389394
void print(string s = "") const;
390395
};
391396

392397
#include <gtsam/nonlinear/ISAM2.h>
393398
class ISAM2GaussNewtonParams {
394-
ISAM2GaussNewtonParams();
399+
ISAM2GaussNewtonParams(double _wildfireThreshold = 0.001);
395400

396401
void print(string s = "") const;
397402

0 commit comments

Comments
 (0)