Skip to content

Commit eef76f9

Browse files
Merge pull request #887 from SciML/ChrisRackauckas-patch-1
Make DiffEqFlux test more robust
2 parents bf83617 + 0d6ed49 commit eef76f9

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

test/diffeqfluxtests.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ ode_data = Array(solve(prob_trueode, Tsit5(), saveat = tsteps))
7070
dudt2 = Lux.Chain(x -> x .^ 3,
7171
Lux.Dense(2, 50, tanh),
7272
Lux.Dense(50, 2))
73-
prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps)
73+
prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps, abstol = 1e-8, reltol = 1e-8)
7474
pp, st = Lux.setup(rng, dudt2)
7575
pp = ComponentArray(pp)
7676

@@ -99,13 +99,13 @@ prob = Optimization.OptimizationProblem(optprob, pp)
9999

100100
result_neuralode = Optimization.solve(prob,
101101
OptimizationOptimisers.ADAM(), callback = callback,
102-
maxiters = 300)
102+
maxiters = 1000)
103103
@test result_neuralode.objectiveloss_neuralode(result_neuralode.u)[1] rtol=1e-2
104104

105105
prob2 = remake(prob, u0 = result_neuralode.u)
106106
result_neuralode2 = Optimization.solve(prob2,
107107
BFGS(initial_stepnorm = 0.0001),
108108
callback = callback,
109-
maxiters = 100)
109+
maxiters = 300, allow_f_increases = true)
110110
@test result_neuralode2.objectiveloss_neuralode(result_neuralode2.u)[1] rtol=1e-2
111111
@test result_neuralode2.objective < 10

0 commit comments

Comments
 (0)