diff --git a/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl b/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl index 1fe611cae..9f42cf07c 100644 --- a/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl +++ b/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl @@ -9,11 +9,13 @@ SciMLBase.requiresgradient(opt::AbstractRule) = true SciMLBase.allowsfg(opt::AbstractRule) = true function SciMLBase.__init( - prob::SciMLBase.OptimizationProblem, opt::AbstractRule; save_best = true, - callback = (args...) -> (false), epochs = nothing, - progress = false, kwargs...) - return OptimizationCache(prob, opt; save_best, callback, progress, epochs, - kwargs...) + prob::SciMLBase.OptimizationProblem, opt::AbstractRule; + callback = (args...) -> (false), + epochs::Union{Number, Nothing} = nothing, + maxiters::Union{Number, Nothing} = nothing, + save_best::Bool = true, progress::Bool = false, kwargs...) + return OptimizationCache(prob, opt; callback, epochs, maxiters, + save_best, progress, kwargs...) end function SciMLBase.__solve(cache::OptimizationCache{ @@ -50,20 +52,27 @@ function SciMLBase.__solve(cache::OptimizationCache{ dataiterate = false end - epochs = if cache.solver_args.epochs === nothing - if cache.solver_args.maxiters === nothing - throw(ArgumentError("The number of iterations must be specified with either the epochs or maxiters kwarg. Where maxiters = epochs*length(data).")) + epochs, maxiters = if isnothing(cache.solver_args.maxiters) && + isnothing(cache.solver_args.epochs) + throw(ArgumentError("The number of iterations must be specified with either the epochs or maxiters kwarg. Where maxiters = epochs * length(data).")) + elseif !isnothing(cache.solver_args.maxiters) && + !isnothing(cache.solver_args.epochs) + if cache.solver_args.maxiters == cache.solver_args.epochs * length(data) + cache.solver_args.epochs, cache.solver_args.maxiters else - cache.solver_args.maxiters / length(data) + throw(ArgumentError("Both maxiters and epochs were passed but maxiters != epochs * length(data).")) end - else - cache.solver_args.epochs + elseif isnothing(cache.solver_args.maxiters) + cache.solver_args.epochs, cache.solver_args.epochs * length(data) + elseif isnothing(cache.solver_args.epochs) + cache.solver_args.maxiters / length(data), cache.solver_args.maxiters end - epochs = Optimization._check_and_convert_maxiters(epochs) - if epochs === nothing - throw(ArgumentError("The number of iterations must be specified with either the epochs or maxiters kwarg. Where maxiters = epochs*length(data).")) - end + maxiters = Optimization._check_and_convert_maxiters(maxiters) + + # At this point, both of them should be fine; but, let's assert it. + @assert (!isnothing(epochs)&&!isnothing(maxiters) && + (maxiters == epochs * length(data))) "The number of iterations must be specified with either the epochs or maxiters kwarg. Where maxiters = epochs * length(data)." opt = cache.opt θ = copy(cache.u0) diff --git a/lib/OptimizationOptimisers/test/runtests.jl b/lib/OptimizationOptimisers/test/runtests.jl index 43a1a1aaa..f84fa2bef 100644 --- a/lib/OptimizationOptimisers/test/runtests.jl +++ b/lib/OptimizationOptimisers/test/runtests.jl @@ -32,6 +32,25 @@ using Lux, MLUtils, Random, ComponentArrays, Printf, MLDataDevices @test sol.stats.fevals == 1000 @test sol.stats.gevals == 1000 + @testset "epochs & maxiters" begin + optprob = SciMLBase.OptimizationFunction( + (u, data) -> sum(u) + sum(data), Optimization.AutoZygote()) + prob = SciMLBase.OptimizationProblem( + optprob, ones(2), MLUtils.DataLoader(ones(2, 2))) + @test_throws ArgumentError("The number of iterations must be specified with either the epochs or maxiters kwarg. Where maxiters = epochs * length(data).") solve( + prob, Optimisers.Adam()) + @test_throws ArgumentError("Both maxiters and epochs were passed but maxiters != epochs * length(data).") solve( + prob, Optimisers.Adam(), epochs = 2, maxiters = 2) + sol = solve(prob, Optimisers.Adam(), epochs = 2) + @test sol.stats.iterations == 4 + sol = solve(prob, Optimisers.Adam(), maxiters = 2) + @test sol.stats.iterations == 2 + sol = solve(prob, Optimisers.Adam(), epochs = 2, maxiters = 4) + @test sol.stats.iterations == 4 + @test_throws AssertionError("The number of iterations must be specified with either the epochs or maxiters kwarg. Where maxiters = epochs * length(data).") solve( + prob, Optimisers.Adam(), maxiters = 3) + end + @testset "cache" begin objective(x, p) = (p[1] - x[1])^2 x0 = zeros(1)