@@ -32,27 +32,28 @@ keyword arguments for the `local_method` of a global optimizer are passed as a
32
32
33
33
Over time, we hope to cover more of these keyword arguments under the common interface.
34
34
35
- If a common argument is not implemented for a optimizer, a warning will be shown .
35
+ A warning will be shown if a common argument is not implemented for an optimizer.
36
36
37
37
## Callback Functions
38
38
39
- The callback function `callback` is a function which is called after every optimizer
39
+ The callback function `callback` is a function that is called after every optimizer
40
40
step. Its signature is:
41
41
42
42
```julia
43
43
callback = (state, loss_val) -> false
44
44
```
45
45
46
- where `state` is a `OptimizationState` and stores information for the current
46
+ where `state` is an `OptimizationState` and stores information for the current
47
47
iteration of the solver and `loss_val` is loss/objective value. For more
48
48
information about the fields of the `state` look at the `OptimizationState`
49
49
documentation. The callback should return a Boolean value, and the default
50
- should be `false`, such that the optimization gets stopped if it returns `true`.
50
+ should be `false`, so the optimization stops if it returns `true`.
51
51
52
52
### Callback Example
53
53
54
- Here we show an example a callback function that plots the prediction at the current value of the optimization variables.
55
- The loss function here returns the loss and the prediction i.e. the solution of the `ODEProblem` `prob`, so we can use the prediction in the callback.
54
+ Here we show an example of a callback function that plots the prediction at the current value of the optimization variables.
55
+ For a visualization callback, we would need the prediction at the current parameters i.e. the solution of the `ODEProblem` `prob`.
56
+ So we call the `predict` function within the callback again.
56
57
57
58
```julia
58
59
function predict(u)
61
62
62
63
function loss(u, p)
63
64
pred = predict(u)
64
- sum(abs2, batch .- pred), pred
65
+ sum(abs2, batch .- pred)
65
66
end
66
67
67
68
callback = function (state, l; doplot = false) #callback function to observe training
0 commit comments