Skip to content

Commit 06bc781

Browse files
authored
Merge pull request #1127 from borglab/feature/python/iterationHook
2 parents 953aa9f + 78c7a6b commit 06bc781

File tree

4 files changed

+143
-65
lines changed

4 files changed

+143
-65
lines changed

gtsam/nonlinear/nonlinear.i

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,9 @@ virtual class NonlinearOptimizerParams {
484484
bool isSequential() const;
485485
bool isCholmod() const;
486486
bool isIterative() const;
487+
488+
// This only applies to Python since MATLAB does not have lambda machinery.
489+
gtsam::NonlinearOptimizerParams::IterationHook iterationHook;
487490
};
488491

489492
bool checkConvergence(double relativeErrorTreshold,

python/gtsam/tests/test_NonlinearOptimizer.py

Lines changed: 59 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -15,76 +15,94 @@
1515
import unittest
1616

1717
import gtsam
18-
from gtsam import (DoglegOptimizer, DoglegParams,
19-
DummyPreconditionerParameters, GaussNewtonOptimizer,
20-
GaussNewtonParams, GncLMParams, GncLMOptimizer,
21-
LevenbergMarquardtOptimizer, LevenbergMarquardtParams,
22-
NonlinearFactorGraph, Ordering,
23-
PCGSolverParameters, Point2, PriorFactorPoint2, Values)
18+
from gtsam import (DoglegOptimizer, DoglegParams, DummyPreconditionerParameters,
19+
GaussNewtonOptimizer, GaussNewtonParams, GncLMParams, GncLMOptimizer,
20+
LevenbergMarquardtOptimizer, LevenbergMarquardtParams, NonlinearFactorGraph,
21+
Ordering, PCGSolverParameters, Point2, PriorFactorPoint2, Values)
2422
from gtsam.utils.test_case import GtsamTestCase
2523

2624
KEY1 = 1
2725
KEY2 = 2
2826

2927

3028
class TestScenario(GtsamTestCase):
31-
def test_optimize(self):
32-
"""Do trivial test with three optimizer variants."""
33-
fg = NonlinearFactorGraph()
29+
"""Do trivial test with three optimizer variants."""
30+
31+
def setUp(self):
32+
"""Set up the optimization problem and ordering"""
33+
# create graph
34+
self.fg = NonlinearFactorGraph()
3435
model = gtsam.noiseModel.Unit.Create(2)
35-
fg.add(PriorFactorPoint2(KEY1, Point2(0, 0), model))
36+
self.fg.add(PriorFactorPoint2(KEY1, Point2(0, 0), model))
3637

3738
# test error at minimum
3839
xstar = Point2(0, 0)
39-
optimal_values = Values()
40-
optimal_values.insert(KEY1, xstar)
41-
self.assertEqual(0.0, fg.error(optimal_values), 0.0)
40+
self.optimal_values = Values()
41+
self.optimal_values.insert(KEY1, xstar)
42+
self.assertEqual(0.0, self.fg.error(self.optimal_values), 0.0)
4243

4344
# test error at initial = [(1-cos(3))^2 + (sin(3))^2]*50 =
4445
x0 = Point2(3, 3)
45-
initial_values = Values()
46-
initial_values.insert(KEY1, x0)
47-
self.assertEqual(9.0, fg.error(initial_values), 1e-3)
46+
self.initial_values = Values()
47+
self.initial_values.insert(KEY1, x0)
48+
self.assertEqual(9.0, self.fg.error(self.initial_values), 1e-3)
4849

4950
# optimize parameters
50-
ordering = Ordering()
51-
ordering.push_back(KEY1)
51+
self.ordering = Ordering()
52+
self.ordering.push_back(KEY1)
5253

53-
# Gauss-Newton
54+
def test_gauss_newton(self):
5455
gnParams = GaussNewtonParams()
55-
gnParams.setOrdering(ordering)
56-
actual1 = GaussNewtonOptimizer(fg, initial_values, gnParams).optimize()
57-
self.assertAlmostEqual(0, fg.error(actual1))
56+
gnParams.setOrdering(self.ordering)
57+
actual = GaussNewtonOptimizer(self.fg, self.initial_values, gnParams).optimize()
58+
self.assertAlmostEqual(0, self.fg.error(actual))
5859

59-
# Levenberg-Marquardt
60+
def test_levenberg_marquardt(self):
6061
lmParams = LevenbergMarquardtParams.CeresDefaults()
61-
lmParams.setOrdering(ordering)
62-
actual2 = LevenbergMarquardtOptimizer(
63-
fg, initial_values, lmParams).optimize()
64-
self.assertAlmostEqual(0, fg.error(actual2))
62+
lmParams.setOrdering(self.ordering)
63+
actual = LevenbergMarquardtOptimizer(self.fg, self.initial_values, lmParams).optimize()
64+
self.assertAlmostEqual(0, self.fg.error(actual))
6565

66-
# Levenberg-Marquardt
66+
def test_levenberg_marquardt_pcg(self):
6767
lmParams = LevenbergMarquardtParams.CeresDefaults()
6868
lmParams.setLinearSolverType("ITERATIVE")
6969
cgParams = PCGSolverParameters()
7070
cgParams.setPreconditionerParams(DummyPreconditionerParameters())
7171
lmParams.setIterativeParams(cgParams)
72-
actual2 = LevenbergMarquardtOptimizer(
73-
fg, initial_values, lmParams).optimize()
74-
self.assertAlmostEqual(0, fg.error(actual2))
72+
actual = LevenbergMarquardtOptimizer(self.fg, self.initial_values, lmParams).optimize()
73+
self.assertAlmostEqual(0, self.fg.error(actual))
7574

76-
# Dogleg
75+
def test_dogleg(self):
7776
dlParams = DoglegParams()
78-
dlParams.setOrdering(ordering)
79-
actual3 = DoglegOptimizer(fg, initial_values, dlParams).optimize()
80-
self.assertAlmostEqual(0, fg.error(actual3))
81-
82-
# Graduated Non-Convexity (GNC)
83-
gncParams = GncLMParams()
84-
actual4 = GncLMOptimizer(fg, initial_values, gncParams).optimize()
85-
self.assertAlmostEqual(0, fg.error(actual4))
86-
77+
dlParams.setOrdering(self.ordering)
78+
actual = DoglegOptimizer(self.fg, self.initial_values, dlParams).optimize()
79+
self.assertAlmostEqual(0, self.fg.error(actual))
8780

81+
def test_graduated_non_convexity(self):
82+
gncParams = GncLMParams()
83+
actual = GncLMOptimizer(self.fg, self.initial_values, gncParams).optimize()
84+
self.assertAlmostEqual(0, self.fg.error(actual))
85+
86+
def test_iteration_hook(self):
87+
# set up iteration hook to track some testable values
88+
iteration_count = 0
89+
final_error = 0
90+
final_values = None
91+
def iteration_hook(iter, error_before, error_after):
92+
nonlocal iteration_count, final_error, final_values
93+
iteration_count = iter
94+
final_error = error_after
95+
final_values = optimizer.values()
96+
# optimize
97+
params = LevenbergMarquardtParams.CeresDefaults()
98+
params.setOrdering(self.ordering)
99+
params.iterationHook = iteration_hook
100+
optimizer = LevenbergMarquardtOptimizer(self.fg, self.initial_values, params)
101+
actual = optimizer.optimize()
102+
self.assertAlmostEqual(0, self.fg.error(actual))
103+
self.gtsamAssertEquals(final_values, actual)
104+
self.assertEqual(self.fg.error(actual), final_error)
105+
self.assertEqual(optimizer.iterations(), iteration_count)
88106

89107
if __name__ == "__main__":
90108
unittest.main()

python/gtsam/tests/test_logging_optimizer.py

Lines changed: 31 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from gtsam import Rot3
1919
from gtsam.utils.test_case import GtsamTestCase
2020

21-
from gtsam.utils.logging_optimizer import gtsam_optimize
21+
from gtsam.utils.logging_optimizer import gtsam_optimize, optimize_using
2222

2323
KEY = 0
2424
MODEL = gtsam.noiseModel.Unit.Create(3)
@@ -34,19 +34,20 @@ def setUp(self):
3434
rotations = {R, R.inverse()} # mean is the identity
3535
self.expected = Rot3()
3636

37-
graph = gtsam.NonlinearFactorGraph()
37+
def check(actual):
38+
# Check that optimizing yields the identity
39+
self.gtsamAssertEquals(actual.atRot3(KEY), self.expected, tol=1e-6)
40+
# Check that logging output prints out 3 lines (exact intermediate values differ by OS)
41+
self.assertEqual(self.capturedOutput.getvalue().count('\n'), 3)
42+
# reset stdout catcher
43+
self.capturedOutput.truncate(0)
44+
self.check = check
45+
46+
self.graph = gtsam.NonlinearFactorGraph()
3847
for R in rotations:
39-
graph.add(gtsam.PriorFactorRot3(KEY, R, MODEL))
40-
initial = gtsam.Values()
41-
initial.insert(KEY, R)
42-
self.params = gtsam.GaussNewtonParams()
43-
self.optimizer = gtsam.GaussNewtonOptimizer(
44-
graph, initial, self.params)
45-
46-
self.lmparams = gtsam.LevenbergMarquardtParams()
47-
self.lmoptimizer = gtsam.LevenbergMarquardtOptimizer(
48-
graph, initial, self.lmparams
49-
)
48+
self.graph.add(gtsam.PriorFactorRot3(KEY, R, MODEL))
49+
self.initial = gtsam.Values()
50+
self.initial.insert(KEY, R)
5051

5152
# setup output capture
5253
self.capturedOutput = StringIO()
@@ -63,22 +64,29 @@ def test_simple_printing(self):
6364
def hook(_, error):
6465
print(error)
6566

66-
# Only thing we require from optimizer is an iterate method
67-
gtsam_optimize(self.optimizer, self.params, hook)
68-
69-
# Check that optimizing yields the identity.
70-
actual = self.optimizer.values()
71-
self.gtsamAssertEquals(actual.atRot3(KEY), self.expected, tol=1e-6)
67+
# Wrapper function sets the hook and calls optimizer.optimize() for us.
68+
params = gtsam.GaussNewtonParams()
69+
actual = optimize_using(gtsam.GaussNewtonOptimizer, hook, self.graph, self.initial)
70+
self.check(actual)
71+
actual = optimize_using(gtsam.GaussNewtonOptimizer, hook, self.graph, self.initial, params)
72+
self.check(actual)
73+
actual = gtsam_optimize(gtsam.GaussNewtonOptimizer(self.graph, self.initial, params),
74+
params, hook)
75+
self.check(actual)
7276

7377
def test_lm_simple_printing(self):
7478
"""Make sure we are properly terminating LM"""
7579
def hook(_, error):
7680
print(error)
7781

78-
gtsam_optimize(self.lmoptimizer, self.lmparams, hook)
79-
80-
actual = self.lmoptimizer.values()
81-
self.gtsamAssertEquals(actual.atRot3(KEY), self.expected, tol=1e-6)
82+
params = gtsam.LevenbergMarquardtParams()
83+
actual = optimize_using(gtsam.LevenbergMarquardtOptimizer, hook, self.graph, self.initial)
84+
self.check(actual)
85+
actual = optimize_using(gtsam.LevenbergMarquardtOptimizer, hook, self.graph, self.initial,
86+
params)
87+
self.check(actual)
88+
actual = gtsam_optimize(gtsam.LevenbergMarquardtOptimizer(self.graph, self.initial, params),
89+
params, hook)
8290

8391
@unittest.skip("Not a test we want run every time, as needs comet.ml account")
8492
def test_comet(self):

python/gtsam/utils/logging_optimizer.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,53 @@
66

77
from gtsam import NonlinearOptimizer, NonlinearOptimizerParams
88
import gtsam
9+
from typing import Any, Callable
10+
11+
OPTIMIZER_PARAMS_MAP = {
12+
gtsam.GaussNewtonOptimizer: gtsam.GaussNewtonParams,
13+
gtsam.LevenbergMarquardtOptimizer: gtsam.LevenbergMarquardtParams,
14+
gtsam.DoglegOptimizer: gtsam.DoglegParams,
15+
gtsam.GncGaussNewtonOptimizer: gtsam.GaussNewtonParams,
16+
gtsam.GncLMOptimizer: gtsam.LevenbergMarquardtParams
17+
}
18+
19+
20+
def optimize_using(OptimizerClass, hook, *args) -> gtsam.Values:
21+
""" Wraps the constructor and "optimize()" call for an Optimizer together and adds an iteration
22+
hook.
23+
Example usage:
24+
```python
25+
def hook(optimizer, error):
26+
print("iteration {:}, error = {:}".format(optimizer.iterations(), error))
27+
solution = optimize_using(gtsam.GaussNewtonOptimizer, hook, graph, init, params)
28+
```
29+
Iteration hook's args are (optimizer, error) and return type should be None
30+
31+
Args:
32+
OptimizerClass (T): A NonlinearOptimizer class (e.g. GaussNewtonOptimizer,
33+
LevenbergMarquardtOptimizer)
34+
hook ([T, double] -> None): Function to callback after each iteration. Args are (optimizer,
35+
error) and return should be None.
36+
*args: Arguments that would be passed into the OptimizerClass constructor, usually:
37+
graph, init, [params]
38+
Returns:
39+
(gtsam.Values): A Values object representing the optimization solution.
40+
"""
41+
# Add the iteration hook to the NonlinearOptimizerParams
42+
for arg in args:
43+
if isinstance(arg, gtsam.NonlinearOptimizerParams):
44+
arg.iterationHook = lambda iteration, error_before, error_after: hook(
45+
optimizer, error_after)
46+
break
47+
else:
48+
params = OPTIMIZER_PARAMS_MAP[OptimizerClass]()
49+
params.iterationHook = lambda iteration, error_before, error_after: hook(
50+
optimizer, error_after)
51+
args = (*args, params)
52+
# Construct Optimizer and optimize
53+
optimizer = OptimizerClass(*args)
54+
hook(optimizer, optimizer.error()) # Call hook once with init values to match behavior below
55+
return optimizer.optimize()
956

1057

1158
def optimize(optimizer, check_convergence, hook):
@@ -21,7 +68,8 @@ def optimize(optimizer, check_convergence, hook):
2168
current_error = optimizer.error()
2269
hook(optimizer, current_error)
2370

24-
# Iterative loop
71+
# Iterative loop. Cannot use `params.iterationHook` because we don't have access to params
72+
# (backwards compatibility issue).
2573
while True:
2674
# Do next iteration
2775
optimizer.iterate()
@@ -36,6 +84,7 @@ def gtsam_optimize(optimizer,
3684
params,
3785
hook):
3886
""" Given an optimizer and params, iterate until convergence.
87+
Recommend using optimize_using instead.
3988
After each iteration, hook(optimizer) is called.
4089
After the function, use values and errors to get the result.
4190
Arguments:

0 commit comments

Comments (0)