Skip to content

Commit e96800d

Browse files
committed
better organization of the codes
1 parent da12a24 commit e96800d

File tree

1 file changed

+11
-21
lines changed

1 file changed

+11
-21
lines changed

lib/OptimizationOptimisers/src/OptimizationOptimisers.jl

+11-21
Original file line numberDiff line numberDiff line change
@@ -50,30 +50,20 @@ function SciMLBase.__solve(cache::OptimizationCache{
5050
dataiterate = false
5151
end
5252

53-
epochs = if cache.solver_args.epochs === nothing
54-
if cache.solver_args.maxiters === nothing
55-
throw(ArgumentError("The number of iterations must be specified with either the epochs or maxiters kwarg. Where maxiters = epochs*length(data)."))
56-
else
57-
cache.solver_args.maxiters / length(data)
58-
end
59-
else
60-
cache.solver_args.epochs
61-
end
62-
maxiters = if cache.solver_args.maxiters === nothing
63-
if cache.solver_args.epochs === nothing
64-
throw(ArgumentError("The number of iterations must be specified with either the epochs or maxiters kwarg. Where maxiters = epochs*length(data)."))
65-
else
66-
cache.solver_args.epochs * length(data)
67-
end
68-
else
69-
cache.solver_args.maxiters
53+
epochs, maxiters = if isnothing(cache.solver_args.maxiters) ==
54+
isnothing(cache.solver_args.epochs)
55+
# both of them are `nothing` or have a value
56+
throw(ArgumentError("The number of iterations must be specified with either the epochs or maxiters kwarg. Where maxiters = epochs * length(data)."))
57+
elseif isnothing(cache.solver_args.maxiters)
58+
cache.solver_args.epochs, cache.solver_args.epochs * length(data)
59+
elseif isnothing(cache.solver_args.epochs)
60+
cache.solver_args.maxiters / length(data), cache.solver_args.maxiters
7061
end
71-
7262
epochs = Optimization._check_and_convert_maxiters(epochs)
7363
maxiters = Optimization._check_and_convert_maxiters(maxiters)
74-
if epochs === nothing || maxiters === nothing
75-
throw(ArgumentError("The number of iterations must be specified with either the epochs or maxiters kwarg. Where maxiters = epochs*length(data)."))
76-
end
64+
65+
# At this point, both of them should be fine; but, let's assert it.
66+
@assert (isnothing(epochs)||isnothing(maxiters) || (maxiters != epochs * length(data))) "The number of iterations must be specified with either the epochs or maxiters kwarg. Where maxiters = epochs * length(data)."
7767

7868
opt = cache.opt
7969
θ = copy(cache.u0)

0 commit comments

Comments
 (0)