From da12a241bf673923732d4253f33348e80ac73952 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Sat, 12 Apr 2025 20:29:11 +0330 Subject: [PATCH 1/8] define `maxiters` similar to `epochs` --- .../src/OptimizationOptimisers.jl | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl b/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl index 1fe611cae..c8ede3a48 100644 --- a/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl +++ b/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl @@ -59,9 +59,19 @@ function SciMLBase.__solve(cache::OptimizationCache{ else cache.solver_args.epochs end + maxiters = if cache.solver_args.maxiters === nothing + if cache.solver_args.epochs === nothing + throw(ArgumentError("The number of iterations must be specified with either the epochs or maxiters kwarg. Where maxiters = epochs*length(data).")) + else + cache.solver_args.epochs * length(data) + end + else + cache.solver_args.maxiters + end epochs = Optimization._check_and_convert_maxiters(epochs) - if epochs === nothing + maxiters = Optimization._check_and_convert_maxiters(maxiters) + if epochs === nothing || maxiters === nothing throw(ArgumentError("The number of iterations must be specified with either the epochs or maxiters kwarg. Where maxiters = epochs*length(data).")) end From e96800dcb6744af5a4aec5b190c18198d036163b Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Sat, 19 Apr 2025 18:48:33 +0330 Subject: [PATCH 2/8] better organization of the codes --- .../src/OptimizationOptimisers.jl | 32 +++++++------------ 1 file changed, 11 insertions(+), 21 deletions(-) diff --git a/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl b/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl index c8ede3a48..40f4af728 100644 --- a/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl +++ b/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl @@ -50,30 +50,20 @@ function SciMLBase.__solve(cache::OptimizationCache{ dataiterate = false end - epochs = if cache.solver_args.epochs === nothing - if cache.solver_args.maxiters === nothing - throw(ArgumentError("The number of iterations must be specified with either the epochs or maxiters kwarg. Where maxiters = epochs*length(data).")) - else - cache.solver_args.maxiters / length(data) - end - else - cache.solver_args.epochs - end - maxiters = if cache.solver_args.maxiters === nothing - if cache.solver_args.epochs === nothing - throw(ArgumentError("The number of iterations must be specified with either the epochs or maxiters kwarg. Where maxiters = epochs*length(data).")) - else - cache.solver_args.epochs * length(data) - end - else - cache.solver_args.maxiters + epochs, maxiters = if isnothing(cache.solver_args.maxiters) == + isnothing(cache.solver_args.epochs) + # both of them are `nothing` or have a value + throw(ArgumentError("The number of iterations must be specified with either the epochs or maxiters kwarg. Where maxiters = epochs * length(data).")) + elseif isnothing(cache.solver_args.maxiters) + cache.solver_args.epochs, cache.solver_args.epochs * length(data) + elseif isnothing(cache.solver_args.epochs) + cache.solver_args.maxiters / length(data), cache.solver_args.maxiters end - epochs = Optimization._check_and_convert_maxiters(epochs) maxiters = Optimization._check_and_convert_maxiters(maxiters) - if epochs === nothing || maxiters === nothing - throw(ArgumentError("The number of iterations must be specified with either the epochs or maxiters kwarg. Where maxiters = epochs*length(data).")) - end + + # At this point, both of them should be fine; but, let's assert it. + @assert (isnothing(epochs)||isnothing(maxiters) || (maxiters != epochs * length(data))) "The number of iterations must be specified with either the epochs or maxiters kwarg. Where maxiters = epochs * length(data)." opt = cache.opt θ = copy(cache.u0) From cc03eae5d4b030b7a293a1759733fc1d0fa69943 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Sat, 19 Apr 2025 18:49:44 +0330 Subject: [PATCH 3/8] add types to the kwargs --- .../src/OptimizationOptimisers.jl | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl b/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl index 40f4af728..784f5d0dd 100644 --- a/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl +++ b/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl @@ -9,11 +9,13 @@ SciMLBase.requiresgradient(opt::AbstractRule) = true SciMLBase.allowsfg(opt::AbstractRule) = true function SciMLBase.__init( - prob::SciMLBase.OptimizationProblem, opt::AbstractRule; save_best = true, - callback = (args...) -> (false), epochs = nothing, - progress = false, kwargs...) - return OptimizationCache(prob, opt; save_best, callback, progress, epochs, - kwargs...) + prob::SciMLBase.OptimizationProblem, opt::AbstractRule; + callback = (args...) -> (false), + epochs::Union{Number, Nothing} = nothing, + maxiters::Union{Number, Nothing} = nothing, + save_best::Bool = true, progress::Bool = false, kwargs...) + return OptimizationCache(prob, opt; callback, epochs, maxiters, + save_best, progress, kwargs...) end function SciMLBase.__solve(cache::OptimizationCache{ From 075bc0fcfbed410a4b30cdf5437f8182873b6a25 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Sat, 19 Apr 2025 19:04:46 +0330 Subject: [PATCH 4/8] check for right numbers --- .../src/OptimizationOptimisers.jl | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl b/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl index 784f5d0dd..276e37d72 100644 --- a/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl +++ b/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl @@ -52,10 +52,16 @@ function SciMLBase.__solve(cache::OptimizationCache{ dataiterate = false end - epochs, maxiters = if isnothing(cache.solver_args.maxiters) == + epochs, maxiters = if isnothing(cache.solver_args.maxiters) && isnothing(cache.solver_args.epochs) - # both of them are `nothing` or have a value throw(ArgumentError("The number of iterations must be specified with either the epochs or maxiters kwarg. Where maxiters = epochs * length(data).")) + elseif !isnothing(cache.solver_args.maxiters) && + !isnothing(cache.solver_args.epochs) + if cache.solver_args.maxiters == cache.solver_args.epochs * length(data) + cache.solver_args.epochs, cache.solver_args.maxiters + else + throw(ArgumentError("The number of iterations must be specified with either the epochs or maxiters kwarg. Where maxiters = epochs * length(data).")) + end elseif isnothing(cache.solver_args.maxiters) cache.solver_args.epochs, cache.solver_args.epochs * length(data) elseif isnothing(cache.solver_args.epochs) From 28302c8c4fa0eedc6873a67594bbeb4b7e1cba4d Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Sat, 19 Apr 2025 19:14:20 +0330 Subject: [PATCH 5/8] add tests --- lib/OptimizationOptimisers/test/runtests.jl | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/lib/OptimizationOptimisers/test/runtests.jl b/lib/OptimizationOptimisers/test/runtests.jl index 43a1a1aaa..6d737525e 100644 --- a/lib/OptimizationOptimisers/test/runtests.jl +++ b/lib/OptimizationOptimisers/test/runtests.jl @@ -32,6 +32,19 @@ using Lux, MLUtils, Random, ComponentArrays, Printf, MLDataDevices @test sol.stats.fevals == 1000 @test sol.stats.gevals == 1000 + @testset "epochs & maxiters" begin + optprob = SciMLBase.OptimizationFunction( + (u, data) -> sum(u) + sum(data), Optimization.AutoZygote()) + prob = SciMLBase.OptimizationProblem(optprob, ones(2), ones(2)) + @test_throws ArgumentError solve(prob, Optimisers.Adam()) + @test_throws ArgumentError solve( + prob, Optimisers.Adam(), epochs = 2, maxiters = 2) + @test solve(prob, Optimisers.Adam(), epochs = 2) + @test solve(prob, Optimisers.Adam(), maxiters = 2) + @test solve(prob, Optimisers.Adam(), epochs = 2, maxiters = 4) + @test_throws AssertionError solve(prob, Optimisers.Adam(), maxiters = 3) + end + @testset "cache" begin objective(x, p) = (p[1] - x[1])^2 x0 = zeros(1) From 243f84823c655badca51b6d48611a6e2167b5850 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Sat, 19 Apr 2025 19:35:05 +0330 Subject: [PATCH 6/8] Update lib/OptimizationOptimisers/src/OptimizationOptimisers.jl Co-authored-by: Vaibhav Kumar Dixit --- lib/OptimizationOptimisers/src/OptimizationOptimisers.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl b/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl index 276e37d72..cdb6b7dd9 100644 --- a/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl +++ b/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl @@ -60,7 +60,7 @@ function SciMLBase.__solve(cache::OptimizationCache{ if cache.solver_args.maxiters == cache.solver_args.epochs * length(data) cache.solver_args.epochs, cache.solver_args.maxiters else - throw(ArgumentError("The number of iterations must be specified with either the epochs or maxiters kwarg. Where maxiters = epochs * length(data).")) + throw(ArgumentError("Both maxiters and epochs were passed but maxiters != epochs * length(data).")) end elseif isnothing(cache.solver_args.maxiters) cache.solver_args.epochs, cache.solver_args.epochs * length(data) From 692b6606d5f41063b7d20aa98de60e3a010d5661 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Sat, 19 Apr 2025 19:41:59 +0330 Subject: [PATCH 7/8] add messages to the error types in tests --- lib/OptimizationOptimisers/test/runtests.jl | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/lib/OptimizationOptimisers/test/runtests.jl b/lib/OptimizationOptimisers/test/runtests.jl index 6d737525e..038e40f81 100644 --- a/lib/OptimizationOptimisers/test/runtests.jl +++ b/lib/OptimizationOptimisers/test/runtests.jl @@ -36,13 +36,18 @@ using Lux, MLUtils, Random, ComponentArrays, Printf, MLDataDevices optprob = SciMLBase.OptimizationFunction( (u, data) -> sum(u) + sum(data), Optimization.AutoZygote()) prob = SciMLBase.OptimizationProblem(optprob, ones(2), ones(2)) - @test_throws ArgumentError solve(prob, Optimisers.Adam()) - @test_throws ArgumentError solve( + @test_throws ArgumentError("The number of iterations must be specified with either the epochs or maxiters kwarg. Where maxiters = epochs * length(data).") solve( + prob, Optimisers.Adam()) + @test_throws ArgumentError("Both maxiters and epochs were passed but maxiters != epochs * length(data).") solve( prob, Optimisers.Adam(), epochs = 2, maxiters = 2) - @test solve(prob, Optimisers.Adam(), epochs = 2) - @test solve(prob, Optimisers.Adam(), maxiters = 2) - @test solve(prob, Optimisers.Adam(), epochs = 2, maxiters = 4) - @test_throws AssertionError solve(prob, Optimisers.Adam(), maxiters = 3) + sol = solve(prob, Optimisers.Adam(), epochs = 2) + @test sol.stats.iterations == 4 + sol = solve(prob, Optimisers.Adam(), maxiters = 2) + @test sol.stats.iterations == 2 + sol = solve(prob, Optimisers.Adam(), epochs = 2, maxiters = 4) + @test sol.stats.iterations == 4 + @test_throws AssertionError("The number of iterations must be specified with either the epochs or maxiters kwarg. Where maxiters = epochs * length(data).") solve( + prob, Optimisers.Adam(), maxiters = 3) end @testset "cache" begin From 5926cd85f324cf2635c56c61ea0e2f5ec52d3b73 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Sat, 19 Apr 2025 21:13:34 +0330 Subject: [PATCH 8/8] fix errors --- lib/OptimizationOptimisers/src/OptimizationOptimisers.jl | 3 ++- lib/OptimizationOptimisers/test/runtests.jl | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl b/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl index cdb6b7dd9..9f42cf07c 100644 --- a/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl +++ b/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl @@ -71,7 +71,8 @@ function SciMLBase.__solve(cache::OptimizationCache{ maxiters = Optimization._check_and_convert_maxiters(maxiters) # At this point, both of them should be fine; but, let's assert it. - @assert (isnothing(epochs)||isnothing(maxiters) || (maxiters != epochs * length(data))) "The number of iterations must be specified with either the epochs or maxiters kwarg. Where maxiters = epochs * length(data)." + @assert (!isnothing(epochs)&&!isnothing(maxiters) && + (maxiters == epochs * length(data))) "The number of iterations must be specified with either the epochs or maxiters kwarg. Where maxiters = epochs * length(data)." opt = cache.opt θ = copy(cache.u0) diff --git a/lib/OptimizationOptimisers/test/runtests.jl b/lib/OptimizationOptimisers/test/runtests.jl index 038e40f81..f84fa2bef 100644 --- a/lib/OptimizationOptimisers/test/runtests.jl +++ b/lib/OptimizationOptimisers/test/runtests.jl @@ -35,7 +35,8 @@ using Lux, MLUtils, Random, ComponentArrays, Printf, MLDataDevices @testset "epochs & maxiters" begin optprob = SciMLBase.OptimizationFunction( (u, data) -> sum(u) + sum(data), Optimization.AutoZygote()) - prob = SciMLBase.OptimizationProblem(optprob, ones(2), ones(2)) + prob = SciMLBase.OptimizationProblem( + optprob, ones(2), MLUtils.DataLoader(ones(2, 2))) @test_throws ArgumentError("The number of iterations must be specified with either the epochs or maxiters kwarg. Where maxiters = epochs * length(data).") solve( prob, Optimisers.Adam()) @test_throws ArgumentError("Both maxiters and epochs were passed but maxiters != epochs * length(data).") solve(