Skip to content

Commit 0d6ed49

Browse files
Update diffeqfluxtests.jl
1 parent 3aba091 commit 0d6ed49

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

test/diffeqfluxtests.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ ode_data = Array(solve(prob_trueode, Tsit5(), saveat = tsteps))
7070
dudt2 = Lux.Chain(x -> x .^ 3,
7171
Lux.Dense(2, 50, tanh),
7272
Lux.Dense(50, 2))
73-
prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps)
73+
prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps, abstol = 1e-8, reltol = 1e-8)
7474
pp, st = Lux.setup(rng, dudt2)
7575
pp = ComponentArray(pp)
7676

@@ -99,7 +99,7 @@ prob = Optimization.OptimizationProblem(optprob, pp)
9999

100100
result_neuralode = Optimization.solve(prob,
101101
OptimizationOptimisers.ADAM(), callback = callback,
102-
maxiters = 500)
102+
maxiters = 1000)
103103
@test result_neuralode.objectiveloss_neuralode(result_neuralode.u)[1] rtol=1e-2
104104

105105
prob2 = remake(prob, u0 = result_neuralode.u)

0 commit comments

Comments
 (0)