Skip to content

Commit 62e6ced

Browse files
committed
Implement moment curve formulation for SOS2
1 parent 78e84cb commit 62e6ced

File tree

3 files changed

+151
-4
lines changed

3 files changed

+151
-4
lines changed

src/PiecewiseLinearOpt.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module PiecewiseLinearOpt
22

3-
import JuMP
3+
import JuMP, MathProgBase, CPLEX
44

55
export PWLFunction, UnivariatePWLFunction, BivariatePWLFunction, piecewiselinear
66

src/jump.jl

+89-3
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,36 @@ defaultmethod() = :Logarithmic
33

44
type PWLData
55
counter::Int
6-
PWLData() = new(0)
6+
branchvars::Vector{Int}
7+
PWLData() = new(0,Int[])
78
end
89

910
function initPWL!(m::JuMP.Model)
1011
if !haskey(m.ext, :PWL)
1112
m.ext[:PWL] = PWLData()
13+
JuMP.setsolvehook(m, function solvehook(m; kwargs...)
14+
if !isempty(m.ext[:PWL].branchvars)
15+
JuMP.build(m)
16+
function branchcallback(d::MathProgBase.MathProgCallbackData)
17+
state = MathProgBase.cbgetstate(d)
18+
if state == :MIPSol
19+
MathProgBase.cbgetmipsolution(d,m.colVal)
20+
else
21+
MathProgBase.cbgetlpsolution(d,m.colVal)
22+
end
23+
moment_curve_branch_callback(m, d)
24+
end
25+
CPLEX.setbranchcallback!(m.internalModel, branchcallback)
26+
function incumbentcallback(d::MathProgBase.MathProgCallbackData)
27+
state = MathProgBase.cbgetstate(d)
28+
@assert state == :MIPIncumbent
29+
m.colVal = copy(d.sol)
30+
moment_curve_incumbent_callback(m, d)
31+
end
32+
CPLEX.setincumbentcallback!(m.internalModel, incumbentcallback)
33+
end
34+
JuMP.solve(m; ignore_solve_hook=true, kwargs...)
35+
end)
1236
end
1337
nothing
1438
end
@@ -76,8 +100,8 @@ function piecewiselinear(m::JuMP.Model, x::JuMP.Variable, pwl::UnivariatePWLFunc
76100
sos2_symmetric_celaya_formulation!(m, λ)
77101
elseif method == :SOS2
78102
JuMP.addSOS2(m, [i*λ[i] for i in 1:n])
79-
else
80-
error("Unrecognized method $method")
103+
elseif method == :MomentCurve
104+
sos2_moment_curve_formulation!(m, λ)
81105
end
82106
end
83107
z
@@ -575,3 +599,65 @@ function piecewiselinear(m::JuMP.Model, x₁::JuMP.Variable, x₂::JuMP.Variable
575599
end
576600
z
577601
end
602+
603+
function sos2_moment_curve_formulation!(m::JuMP.Model, λ)
604+
counter = m.ext[:PWL].counter
605+
d = length(λ)-1
606+
y = JuMP.@variable(m, [i=1:2], Int, lowerbound=1, upperbound=d^i, basename="y_$counter")
607+
for i in 1:d
608+
JuMP.@constraints(m, begin
609+
-2i*λ[1] + sum((v^2-(2i+1)*v+2min(0,i+1-v))*λ[v] for v in 2:d) + (d^2-(2i+1)*d)λ[d+1] -(2i+1)*y[1] + y[2]
610+
-2i*λ[1] + sum((v^2-(2i+1)*v+2max(0,i+1-v))*λ[v] for v in 2:d) + (d^2-(2i+1)*d)λ[d+1] -(2i+1)*y[1] + y[2]
611+
end)
612+
end
613+
push!(m.ext[:PWL].branchvars, JuMP.linearindex(y[1]))
614+
nothing
615+
end
616+
617+
function sos2_moment_curve_formulation!(m::JuMP.Model, λ)
618+
counter = m.ext[:PWL].counter
619+
d = length(λ)-1
620+
y = JuMP.@variable(m, [i=1:2], Int, lowerbound=1, upperbound=d^i, basename="y_$counter")
621+
for i in 1:d
622+
JuMP.@constraints(m, begin
623+
-2i*λ[1] + sum((v^2-(2i+1)*v+2min(0,i+1-v))*λ[v] for v in 2:d) + (d^2-(2i+1)*d)λ[d+1] -(2i+1)*y[1] + y[2]
624+
-2i*λ[1] + sum((v^2-(2i+1)*v+2max(0,i+1-v))*λ[v] for v in 2:d) + (d^2-(2i+1)*d)λ[d+1] -(2i+1)*y[1] + y[2]
625+
end)
626+
end
627+
push!(m.ext[:PWL].branchvars, JuMP.linearindex(y[1]))
628+
nothing
629+
end
630+
631+
function moment_curve_branch_callback(m, cb)
632+
# if CPLEX was gonna branch anyway, just use their branching decision
633+
if !isempty(cb.nodes)
634+
unsafe_store!(cb.userinteraction_p, Cint(0))
635+
return nothing
636+
end
637+
xval = MathProgBase.cbgetlpsolution(cb)
638+
TOL = 1e-4
639+
for i in branchvars
640+
if (ceil(xval[i]) - xval[i] > TOL) && (xval[i]-floor(xval[i]) > TOL)
641+
branch_ind = i
642+
y = [JuMP.Variable(m, i), JuMP.Variable(m, i+1)]
643+
break
644+
end
645+
end
646+
l, u = MathProgBase.cbgetnodelb(cb), MathProgBase.cbgetnodeub(cb)
647+
uᶠ, lᶜ = floor(xval[branch_id]), ceil(xval[branch_id])
648+
addbranch(cb, (uᶠ-l )*(y[2]-l ^2) (uᶠ^2-l ^2)*(y[1]-l ))
649+
addbranch(cb, (u -lᶜ)*(y[2]-lᶜ^2) (u ^2-lᶜ^2)*(y[1]-lᶜ))
650+
nothing
651+
end
652+
653+
function moment_curve_incumbent_callback(m, cb)
654+
xval = MathProgBase.cbgetmipsolution(cb)
655+
for i in m.ext[:PWL].branchvars
656+
if !isapprox(xval[i]^2, xval[i+1], rtol=1e-4)
657+
CPLEX.rejectincumbent(cb)
658+
return nothing
659+
end
660+
end
661+
CPLEX.acceptincumbent(cb)
662+
return nothing
663+
end

test/pwl-trials.jl

+61
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
using PiecewiseLinearOpt
2+
using Base.Test
3+
4+
using JuMP, CPLEX
5+
6+
const solver = CplexSolver(CPX_PARAM_TILIM=30*60.0, CPX_PARAM_MIPCBREDLP=0)
7+
8+
fp = open("1D-pwl-results.csv", "w+")
9+
fp2 = open("1D-pwl-objective-value.csv", "w+")
10+
11+
methods = (:MomentCurve,:Incremental,:MC,:CC,:Logarithmic)
12+
13+
println(fp, "instance, ", join(methods, ", "))
14+
println(fp2, "instance, ", join(methods, ", "))
15+
16+
for instance in readdir(joinpath(Pkg.dir("PiecewiseLinearOpt"),"test","1D-pwl-instances"))
17+
print(fp, "$instance")
18+
print(fp2, "$instance")
19+
20+
folder = joinpath(Pkg.dir("PiecewiseLinearOpt"),"test","1D-pwl-instances",instance)
21+
22+
demand = readdlm(joinpath(folder, "dem.dat"))
23+
supply = readdlm(joinpath(folder, "sup.dat"))
24+
numdem = size(demand, 1)
25+
numsup = size(supply, 1)
26+
27+
d = readdlm(joinpath(folder, "mat.dat"))
28+
fd = readdlm(joinpath(folder, "obj.dat"))
29+
K = size(d, 2)
30+
31+
for method in methods
32+
model = Model(solver=solver)
33+
@variable(model, x[1:numsup,1:numdem] 0)
34+
for j in 1:numdem
35+
# demand constraint
36+
@constraint(model, sum(x[i,j] for i in 1:numsup) == demand[j])
37+
end
38+
for i in 1:numsup
39+
# supply constraint
40+
@constraint(model, sum(x[i,j] for j in 1:numdem) == supply[i])
41+
end
42+
43+
idx = 1
44+
obj = AffExpr()
45+
for i in 1:numsup, j in 1:numdem
46+
z = piecewiselinear(model, x[i,j], d[idx,:], fd[idx,:], method=method)
47+
obj += z
48+
end
49+
@objective(model, Min, obj)
50+
51+
tm = @elapsed solve(model)
52+
print(fp, ", $tm")
53+
flush(fp)
54+
print(fp2, ", $(getobjectivevalue(model))")
55+
flush(fp2)
56+
57+
error()
58+
end
59+
println(fp)
60+
println(fp2)
61+
end

0 commit comments

Comments
 (0)