Skip to content

Commit 7acf686

Browse files
committed
implement vaidya minimizer and hessian computation
1 parent e6dd7fd commit 7acf686

File tree

2 files changed

+35
-7
lines changed

2 files changed

+35
-7
lines changed

include/preprocess/inscribed_ellipsoid_rounding.hpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ compute_inscribed_ellipsoid(Custom_MT A, VT b, VT const& x0,
2626
{
2727
return max_inscribed_ellipsoid<MT>(A, b, x0, maxiter, tol, reg);
2828
} else if constexpr (ellipsoid_type == EllipsoidType::LOG_BARRIER ||
29-
ellipsoid_type == EllipsoidType::VOLUMETRIC_BARRIER)
29+
ellipsoid_type == EllipsoidType::VOLUMETRIC_BARRIER ||
30+
ellipsoid_type == EllipsoidType::VAIDYA_BARRIER)
3031
{
3132
return barrier_center_ellipsoid_linear_ineq<MT, ellipsoid_type, NT>(A, b, x0);
3233
} else

include/preprocess/rounding_util_functions.hpp

+33-6
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ enum EllipsoidType
2222
{
2323
MAX_ELLIPSOID = 1,
2424
LOG_BARRIER = 2,
25-
VOLUMETRIC_BARRIER = 3
25+
VOLUMETRIC_BARRIER = 3,
26+
VAIDYA_BARRIER = 4
2627
};
2728

2829
template <int T>
@@ -345,7 +346,8 @@ std::tuple<NT, NT> init_step()
345346
if constexpr (BarrierType == EllipsoidType::LOG_BARRIER)
346347
{
347348
return {NT(1), NT(0.99)};
348-
} else if constexpr (BarrierType == EllipsoidType::VOLUMETRIC_BARRIER)
349+
} else if constexpr (BarrierType == EllipsoidType::VOLUMETRIC_BARRIER ||
350+
BarrierType == EllipsoidType::VAIDYA_BARRIER)
349351
{
350352
return {NT(0.5), NT(0.4)};
351353
} else {
@@ -362,21 +364,43 @@ void get_barrier_hessian_grad(MT const& A, MT const& A_trans, VT const& b,
362364
b_Ax.noalias() = b - Ax;
363365
VT s = b_Ax.cwiseInverse();
364366
VT s_sq = s.cwiseProduct(s);
367+
VT sigma;
365368
// Hessian of the log-barrier function
366369
update_Atrans_Diag_A<NT>(H, A_trans, A, s_sq.asDiagonal());
370+
371+
if constexpr (BarrierType == EllipsoidType::VOLUMETRIC_BARRIER ||
372+
BarrierType == EllipsoidType::VAIDYA_BARRIER)
373+
{
374+
// Computing sigma(x)_i = (a_i^T H^{-1} a_i) / (b_i - a_i^Tx)^2
375+
MT_dense HA = solve_mat(llt, H, A_trans, obj_val);
376+
MT_dense aiHai = HA.transpose().cwiseProduct(A);
377+
sigma = (aiHai.rowwise().sum()).cwiseProduct(s_sq);
378+
}
379+
367380
if constexpr (BarrierType == EllipsoidType::LOG_BARRIER)
368381
{
369382
grad.noalias() = A_trans * s;
370383
} else if constexpr (BarrierType == EllipsoidType::VOLUMETRIC_BARRIER)
371384
{
372-
// Computing sigma(x)_i = (a_i^T H^{-1} a_i) / (b_i - a_i^Tx)^2
373-
MT_dense HA = solve_mat(llt, H, A_trans, obj_val);
374-
MT_dense aiHai = HA.transpose().cwiseProduct(A);
375-
VT sigma = (aiHai.rowwise().sum()).cwiseProduct(s_sq);
376385
// Gradient of the volumetric barrier function
377386
grad.noalias() = A_trans * (s.cwiseProduct(sigma));
378387
// Hessian of the volumetric barrier function
379388
update_Atrans_Diag_A<NT>(H, A_trans, A, s_sq.cwiseProduct(sigma).asDiagonal());
389+
} else if constexpr (BarrierType == EllipsoidType::VAIDYA_BARRIER)
390+
{
391+
const int m = b.size(), d = x.size();
392+
// Weighted gradient of the log barrier function
393+
grad.noalias() = A_trans * s;
394+
grad *= NT(d) / NT(m);
395+
// Add the gradient of the volumetric function
396+
grad.noalias() += A_trans * (s.cwiseProduct(sigma));
397+
// Weighted Hessian of the log barrier function
398+
H *= NT(d) / NT(m);
399+
// Add the Hessian of the volumetric function
400+
MT Hvol(d, d);
401+
update_Atrans_Diag_A<NT>(Hvol, A_trans, A, s_sq.cwiseProduct(sigma).asDiagonal());
402+
H += Hvol;
403+
obj_val -= s.array().log().sum();
380404
} else {
381405
static_assert(AssertBarrierFalseType<BarrierType>::value,
382406
"Barrier type is not supported.");
@@ -393,6 +417,9 @@ void get_step_next_iteration(NT const obj_val_prev, NT const obj_val,
393417
} else if constexpr (BarrierType == EllipsoidType::VOLUMETRIC_BARRIER)
394418
{
395419
step_iter *= (obj_val_prev <= obj_val - tol_obj) ? NT(0.9) : NT(0.999);
420+
} else if constexpr (BarrierType == EllipsoidType::VAIDYA_BARRIER)
421+
{
422+
step_iter *= NT(0.999);
396423
} else {
397424
static_assert(AssertBarrierFalseType<BarrierType>::value,
398425
"Barrier type is not supported.");

0 commit comments

Comments
 (0)