
Commit 7a74b3b

Merge pull request #68 from SciML/dataarg
Don't use DiffResults in Flux optimiser dispatch with FiniteDiff
2 parents: 2d585bb + 69127dd

4 files changed: 8 additions & 5 deletions


Project.toml

Lines changed: 1 addition & 1 deletion

@@ -1,7 +1,7 @@
 name = "GalacticOptim"
 uuid = "a75be94c-b780-496d-a8a9-0878b188d577"
 authors = ["Vaibhavdixit02 <[email protected]>"]
-version = "0.4.0"
+version = "0.4.1"

 [deps]
 ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

src/function.jl

Lines changed: 2 additions & 2 deletions

@@ -212,13 +212,13 @@ function instantiate_function(f, x, adtype::AutoFiniteDiff, p, num_cons = 0)
     _f = (θ, args...) -> first(f.f(θ, p, args...))

     if f.grad === nothing
-        grad = (res, θ, args...) -> FiniteDiff.finite_difference_gradient!(res,x ->_f(x, args...), θ, FiniteDiff.GradientCache(res, x, adtype.fdtype))
+        grad = (res, θ, args...) -> FiniteDiff.finite_difference_gradient!(res, x ->_f(x, args...), θ, FiniteDiff.GradientCache(res, x, adtype.fdtype))
     else
         grad = f.grad
     end

     if f.hess === nothing
-        hess = (res, θ, args...) -> FiniteDiff.finite_difference_hessian!(res,x ->_f(x, args...), θ, FiniteDiff.HessianCache(x, adtype.fdhtype))
+        hess = (res, θ, args...) -> FiniteDiff.finite_difference_hessian!(res, x ->_f(x, args...), θ, FiniteDiff.HessianCache(x, adtype.fdhtype))
     else
         hess = f.hess
     end
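
For AutoFiniteDiff, the generated grad closure fills a preallocated buffer in place via FiniteDiff.finite_difference_gradient!. The minimal sketch below is not part of this commit: the objective, evaluation point, and buffer names are made up for illustration, and it omits the FiniteDiff.GradientCache that GalacticOptim builds from adtype.fdtype.

using FiniteDiff

# Illustrative objective and evaluation point, not taken from the diff.
rosen(θ) = (1.0 - θ[1])^2 + 100.0 * (θ[2] - θ[1]^2)^2
θ = zeros(2)

res = zeros(length(θ))   # plain mutable buffer; no DiffResults container needed
FiniteDiff.finite_difference_gradient!(res, rosen, θ)
# res ≈ [-2.0, 0.0] at θ = [0, 0]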

src/solve.jl

Lines changed: 2 additions & 2 deletions

@@ -94,7 +94,7 @@ function __solve(prob::OptimizationProblem, opt, _data = DEFAULT_DATA;cb = (args

     @withprogress progress name="Training" begin
         for (i,d) in enumerate(data)
-            gs = DiffResults.GradientResult(θ)
+            gs = prob.f.adtype isa AutoFiniteDiff ? Array{Number}(undef,length(θ)) : DiffResults.GradientResult(θ)
             f.grad(gs, θ, d...)
             x = f.f(θ, prob.p, d...)
             cb_call = cb(θ, x...)
@@ -105,7 +105,7 @@ function __solve(prob::OptimizationProblem, opt, _data = DEFAULT_DATA;cb = (args
             end
             msg = @sprintf("loss: %.3g", x[1])
             progress && ProgressLogging.@logprogress msg i/maxiters
-            update!(opt, ps, DiffResults.gradient(gs))
+            update!(opt, ps, prob.f.adtype isa AutoFiniteDiff ? gs : DiffResults.gradient(gs))

             if save_best
                 if first(x) < first(min_err) #found a better solution
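
The change above swaps the gradient container depending on the AD backend: with AutoFiniteDiff a plain array (the commit allocates Array{Number}(undef, length(θ))) is filled in place and handed to update! directly, while the other backends keep writing into a DiffResults.GradientResult and unwrapping it with DiffResults.gradient. A minimal sketch of the two paths, using placeholder names (loss, θ) that are not from the commit:

using DiffResults, ForwardDiff, FiniteDiff

loss(θ) = sum(abs2, θ)
θ = rand(3)

# DiffResults path (non-FiniteDiff backends): the gradient lives inside the
# GradientResult and is extracted with DiffResults.gradient before update!.
gs_dr = DiffResults.GradientResult(θ)
ForwardDiff.gradient!(gs_dr, loss, θ)
g_forward = DiffResults.gradient(gs_dr)

# FiniteDiff path (this commit): a plain array is filled in place and can be
# passed along as-is, so no DiffResults unwrapping is needed.
gs_fd = similar(θ)
FiniteDiff.finite_difference_gradient!(gs_fd, loss, θ)
g_finite = gs_fd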

test/ADtests.jl

Lines changed: 3 additions & 0 deletions

@@ -107,3 +107,6 @@ sol = solve(prob, Newton())

 sol = solve(prob, Optim.KrylovTrustRegion())
 @test sol.minimum < l1 #the loss doesn't go below 5e-1 here
+
+sol = solve(prob, ADAM(0.1))
+@test 10*sol.minimum < l1
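
The new test runs a Flux optimiser (ADAM) against prob, which, given the purpose of this commit, presumably uses a FiniteDiff backend at that point in the file, exercising the branch changed in src/solve.jl. As a standalone sketch only (the objective, initial point, parameters, and maxiters below are assumptions, not copied from test/ADtests.jl), the scenario looks roughly like:

using GalacticOptim, Flux

# Assumed setup, not copied from the test file.
rosenbrock(θ, p) = (p[1] - θ[1])^2 + p[2] * (θ[2] - θ[1]^2)^2
x0 = zeros(2)
p = [1.0, 100.0]

optf = OptimizationFunction(rosenbrock, GalacticOptim.AutoFiniteDiff())
prob = OptimizationProblem(optf, x0, p)

# The Flux optimiser dispatch now builds a plain gradient buffer for
# AutoFiniteDiff instead of a DiffResults.GradientResult.
sol = solve(prob, ADAM(0.1), maxiters = 1000)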
