Skip to content

Commit 78c7a6b

Browse files
committed
Change optimize_using to simpler function call
1 parent 1e03c8b commit 78c7a6b

File tree

2 files changed

+30
-27
lines changed

2 files changed

+30
-27
lines changed

python/gtsam/tests/test_logging_optimizer.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,9 @@ def hook(_, error):
6666

6767
# Wrapper function sets the hook and calls optimizer.optimize() for us.
6868
params = gtsam.GaussNewtonParams()
69-
actual = optimize_using(gtsam.GaussNewtonOptimizer, hook)(self.graph, self.initial)
69+
actual = optimize_using(gtsam.GaussNewtonOptimizer, hook, self.graph, self.initial)
7070
self.check(actual)
71-
actual = optimize_using(gtsam.GaussNewtonOptimizer, hook)(self.graph, self.initial, params)
71+
actual = optimize_using(gtsam.GaussNewtonOptimizer, hook, self.graph, self.initial, params)
7272
self.check(actual)
7373
actual = gtsam_optimize(gtsam.GaussNewtonOptimizer(self.graph, self.initial, params),
7474
params, hook)
@@ -80,10 +80,10 @@ def hook(_, error):
8080
print(error)
8181

8282
params = gtsam.LevenbergMarquardtParams()
83-
actual = optimize_using(gtsam.LevenbergMarquardtOptimizer, hook)(self.graph, self.initial)
83+
actual = optimize_using(gtsam.LevenbergMarquardtOptimizer, hook, self.graph, self.initial)
8484
self.check(actual)
85-
actual = optimize_using(gtsam.LevenbergMarquardtOptimizer, hook)(self.graph, self.initial,
86-
params)
85+
actual = optimize_using(gtsam.LevenbergMarquardtOptimizer, hook, self.graph, self.initial,
86+
params)
8787
self.check(actual)
8888
actual = gtsam_optimize(gtsam.LevenbergMarquardtOptimizer(self.graph, self.initial, params),
8989
params, hook)

python/gtsam/utils/logging_optimizer.py

Lines changed: 25 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -17,39 +17,42 @@
1717
}
1818

1919

20-
def optimize_using(OptimizerClass, hook) -> Callable[[Any], gtsam.Values]:
20+
def optimize_using(OptimizerClass, hook, *args) -> gtsam.Values:
2121
""" Wraps the constructor and "optimize()" call for an Optimizer together and adds an iteration
2222
hook.
2323
Example usage:
24-
solution = optimize_using(gtsam.GaussNewtonOptimizer, hook)(graph, init, params)
24+
```python
25+
def hook(optimizer, error):
26+
print("iteration {:}, error = {:}".format(optimizer.iterations(), error))
27+
solution = optimize_using(gtsam.GaussNewtonOptimizer, hook, graph, init, params)
28+
```
29+
Iteration hook's args are (optimizer, error) and return type should be None
2530
2631
Args:
2732
OptimizerClass (T): A NonlinearOptimizer class (e.g. GaussNewtonOptimizer,
28-
LevenbergMarquadrtOptimizer)
33+
LevenbergMarquardtOptimizer)
2934
hook ([T, double] -> None): Function to callback after each iteration. Args are (optimizer,
3035
error) and return should be None.
36+
*args: Arguments that would be passed into the OptimizerClass constructor, usually:
37+
graph, init, [params]
3138
Returns:
32-
(Callable[*, gtsam.Values]): Call the returned function with the usual NonlinearOptimizer
33-
arguments (will be forwarded to constructor) and it will return a Values object
34-
representing the solution. See example usage above.
39+
(gtsam.Values): A Values object representing the optimization solution.
3540
"""
36-
37-
def wrapped_optimize(*args):
38-
for arg in args:
39-
if isinstance(arg, gtsam.NonlinearOptimizerParams):
40-
arg.iterationHook = lambda iteration, error_before, error_after: hook(
41-
optimizer, error_after)
42-
break
43-
else:
44-
params = OPTIMIZER_PARAMS_MAP[OptimizerClass]()
45-
params.iterationHook = lambda iteration, error_before, error_after: hook(
41+
# Add the iteration hook to the NonlinearOptimizerParams
42+
for arg in args:
43+
if isinstance(arg, gtsam.NonlinearOptimizerParams):
44+
arg.iterationHook = lambda iteration, error_before, error_after: hook(
4645
optimizer, error_after)
47-
args = (*args, params)
48-
optimizer = OptimizerClass(*args)
49-
hook(optimizer, optimizer.error())
50-
return optimizer.optimize()
51-
52-
return wrapped_optimize
46+
break
47+
else:
48+
params = OPTIMIZER_PARAMS_MAP[OptimizerClass]()
49+
params.iterationHook = lambda iteration, error_before, error_after: hook(
50+
optimizer, error_after)
51+
args = (*args, params)
52+
# Construct Optimizer and optimize
53+
optimizer = OptimizerClass(*args)
54+
hook(optimizer, optimizer.error()) # Call hook once with init values to match behavior below
55+
return optimizer.optimize()
5356

5457

5558
def optimize(optimizer, check_convergence, hook):

0 commit comments

Comments
 (0)