|
15 | 15 | import unittest |
16 | 16 |
|
17 | 17 | import gtsam |
18 | | -from gtsam import (DoglegOptimizer, DoglegParams, |
19 | | - DummyPreconditionerParameters, GaussNewtonOptimizer, |
20 | | - GaussNewtonParams, GncLMParams, GncLMOptimizer, |
21 | | - LevenbergMarquardtOptimizer, LevenbergMarquardtParams, |
22 | | - NonlinearFactorGraph, Ordering, |
23 | | - PCGSolverParameters, Point2, PriorFactorPoint2, Values) |
| 18 | +from gtsam import (DoglegOptimizer, DoglegParams, DummyPreconditionerParameters, |
| 19 | + GaussNewtonOptimizer, GaussNewtonParams, GncLMParams, GncLMOptimizer, |
| 20 | + LevenbergMarquardtOptimizer, LevenbergMarquardtParams, NonlinearFactorGraph, |
| 21 | + Ordering, PCGSolverParameters, Point2, PriorFactorPoint2, Values) |
24 | 22 | from gtsam.utils.test_case import GtsamTestCase |
25 | 23 |
|
26 | 24 | KEY1 = 1 |
27 | 25 | KEY2 = 2 |
28 | 26 |
|
29 | 27 |
|
30 | 28 | class TestScenario(GtsamTestCase): |
31 | | - def test_optimize(self): |
32 | | - """Do trivial test with three optimizer variants.""" |
33 | | - fg = NonlinearFactorGraph() |
| 29 | + """Do trivial test with three optimizer variants.""" |
| 30 | + |
| 31 | + def setUp(self): |
| 32 | + """Set up the optimization problem and ordering""" |
| 33 | + # create graph |
| 34 | + self.fg = NonlinearFactorGraph() |
34 | 35 | model = gtsam.noiseModel.Unit.Create(2) |
35 | | - fg.add(PriorFactorPoint2(KEY1, Point2(0, 0), model)) |
| 36 | + self.fg.add(PriorFactorPoint2(KEY1, Point2(0, 0), model)) |
36 | 37 |
|
37 | 38 | # test error at minimum |
38 | 39 | xstar = Point2(0, 0) |
39 | | - optimal_values = Values() |
40 | | - optimal_values.insert(KEY1, xstar) |
41 | | - self.assertEqual(0.0, fg.error(optimal_values), 0.0) |
| 40 | + self.optimal_values = Values() |
| 41 | + self.optimal_values.insert(KEY1, xstar) |
| 42 | + self.assertEqual(0.0, self.fg.error(self.optimal_values), 0.0) |
42 | 43 |
|
43 | 44 | # test error at initial = [(1-cos(3))^2 + (sin(3))^2]*50 = |
44 | 45 | x0 = Point2(3, 3) |
45 | | - initial_values = Values() |
46 | | - initial_values.insert(KEY1, x0) |
47 | | - self.assertEqual(9.0, fg.error(initial_values), 1e-3) |
| 46 | + self.initial_values = Values() |
| 47 | + self.initial_values.insert(KEY1, x0) |
| 48 | + self.assertEqual(9.0, self.fg.error(self.initial_values), 1e-3) |
48 | 49 |
|
49 | 50 | # optimize parameters |
50 | | - ordering = Ordering() |
51 | | - ordering.push_back(KEY1) |
| 51 | + self.ordering = Ordering() |
| 52 | + self.ordering.push_back(KEY1) |
52 | 53 |
|
53 | | - # Gauss-Newton |
| 54 | + def test_gauss_newton(self): |
54 | 55 | gnParams = GaussNewtonParams() |
55 | | - gnParams.setOrdering(ordering) |
56 | | - actual1 = GaussNewtonOptimizer(fg, initial_values, gnParams).optimize() |
57 | | - self.assertAlmostEqual(0, fg.error(actual1)) |
| 56 | + gnParams.setOrdering(self.ordering) |
| 57 | + actual = GaussNewtonOptimizer(self.fg, self.initial_values, gnParams).optimize() |
| 58 | + self.assertAlmostEqual(0, self.fg.error(actual)) |
58 | 59 |
|
59 | | - # Levenberg-Marquardt |
| 60 | + def test_levenberg_marquardt(self): |
60 | 61 | lmParams = LevenbergMarquardtParams.CeresDefaults() |
61 | | - lmParams.setOrdering(ordering) |
62 | | - actual2 = LevenbergMarquardtOptimizer( |
63 | | - fg, initial_values, lmParams).optimize() |
64 | | - self.assertAlmostEqual(0, fg.error(actual2)) |
| 62 | + lmParams.setOrdering(self.ordering) |
| 63 | + actual = LevenbergMarquardtOptimizer(self.fg, self.initial_values, lmParams).optimize() |
| 64 | + self.assertAlmostEqual(0, self.fg.error(actual)) |
65 | 65 |
|
66 | | - # Levenberg-Marquardt |
| 66 | + def test_levenberg_marquardt_pcg(self): |
67 | 67 | lmParams = LevenbergMarquardtParams.CeresDefaults() |
68 | 68 | lmParams.setLinearSolverType("ITERATIVE") |
69 | 69 | cgParams = PCGSolverParameters() |
70 | 70 | cgParams.setPreconditionerParams(DummyPreconditionerParameters()) |
71 | 71 | lmParams.setIterativeParams(cgParams) |
72 | | - actual2 = LevenbergMarquardtOptimizer( |
73 | | - fg, initial_values, lmParams).optimize() |
74 | | - self.assertAlmostEqual(0, fg.error(actual2)) |
| 72 | + actual = LevenbergMarquardtOptimizer(self.fg, self.initial_values, lmParams).optimize() |
| 73 | + self.assertAlmostEqual(0, self.fg.error(actual)) |
75 | 74 |
|
76 | | - # Dogleg |
| 75 | + def test_dogleg(self): |
77 | 76 | dlParams = DoglegParams() |
78 | | - dlParams.setOrdering(ordering) |
79 | | - actual3 = DoglegOptimizer(fg, initial_values, dlParams).optimize() |
80 | | - self.assertAlmostEqual(0, fg.error(actual3)) |
81 | | - |
82 | | - # Graduated Non-Convexity (GNC) |
83 | | - gncParams = GncLMParams() |
84 | | - actual4 = GncLMOptimizer(fg, initial_values, gncParams).optimize() |
85 | | - self.assertAlmostEqual(0, fg.error(actual4)) |
86 | | - |
| 77 | + dlParams.setOrdering(self.ordering) |
| 78 | + actual = DoglegOptimizer(self.fg, self.initial_values, dlParams).optimize() |
| 79 | + self.assertAlmostEqual(0, self.fg.error(actual)) |
87 | 80 |
|
| 81 | + def test_graduated_non_convexity(self): |
| 82 | + gncParams = GncLMParams() |
| 83 | + actual = GncLMOptimizer(self.fg, self.initial_values, gncParams).optimize() |
| 84 | + self.assertAlmostEqual(0, self.fg.error(actual)) |
| 85 | + |
| 86 | + def test_iteration_hook(self): |
| 87 | + # set up iteration hook to track some testable values |
| 88 | + iteration_count = 0 |
| 89 | + final_error = 0 |
| 90 | + final_values = None |
| 91 | + def iteration_hook(iter, error_before, error_after): |
| 92 | + nonlocal iteration_count, final_error, final_values |
| 93 | + iteration_count = iter |
| 94 | + final_error = error_after |
| 95 | + final_values = optimizer.values() |
| 96 | + # optimize |
| 97 | + params = LevenbergMarquardtParams.CeresDefaults() |
| 98 | + params.setOrdering(self.ordering) |
| 99 | + params.iterationHook = iteration_hook |
| 100 | + optimizer = LevenbergMarquardtOptimizer(self.fg, self.initial_values, params) |
| 101 | + actual = optimizer.optimize() |
| 102 | + self.assertAlmostEqual(0, self.fg.error(actual)) |
| 103 | + self.gtsamAssertEquals(final_values, actual) |
| 104 | + self.assertEqual(self.fg.error(actual), final_error) |
| 105 | + self.assertEqual(optimizer.iterations(), iteration_count) |
88 | 106 |
|
89 | 107 | if __name__ == "__main__": |
90 | 108 | unittest.main() |
0 commit comments