Skip to content

Commit d076bf0

Browse files
authored
Position Nudging after Position Update (#308)
* Position Nudging after Position Update * Complexity improvements and Polytope Normalization * HPolytope Normalization Flag * Polytope normalization within Walk Constructor * Alias HPolytope Normalization for Nudging inside Gaussian HMC * Polytope Normalization in ComputeInner Ball Fixed * Polytope Normalization Style change * Nudge in function within Gaussian HMC, and restore Hpoly file * More efficient Nudge in Process
1 parent 723869e commit d076bf0

File tree

2 files changed

+51
-6
lines changed

2 files changed

+51
-6
lines changed

include/convex_bodies/hpolytope.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -927,4 +927,4 @@ class HPolytope {
927927
}
928928
};
929929

930-
#endif
930+
#endif

include/random_walks/gaussian_hamiltonian_monte_carlo_exact_walk.hpp

Lines changed: 50 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ struct Walk
5353
typedef typename Polytope::PointType Point;
5454
typedef typename Point::FT NT;
5555
typedef typename Polytope::VT VT;
56+
typedef typename Polytope::MT MT;
5657

5758
template <typename GenericPolytope>
5859
Walk(GenericPolytope &P, Point const& p, NT const& a_i, RandomNumberGenerator &rng)
@@ -80,7 +81,7 @@ struct Walk
8081
<
8182
typename GenericPolytope
8283
>
83-
inline void apply(GenericPolytope& P,
84+
inline void apply(GenericPolytope const& P,
8485
Point& p,
8586
NT const& a_i,
8687
unsigned int const& walk_length,
@@ -89,6 +90,9 @@ struct Walk
8990
unsigned int n = P.dimension();
9091
NT T;
9192

93+
GenericPolytope P_normalized = P;
94+
P_normalized.normalize();
95+
9296
for (auto j=0u; j<walk_length; ++j)
9397
{
9498
T = rng.sample_urdist() * _Len;
@@ -105,6 +109,7 @@ struct Walk
105109
_lambda_prev = pbpair.first;
106110
T -= _lambda_prev;
107111
update_position(_p, _v, _lambda_prev, _omega);
112+
nudge_in(P_normalized, _p);
108113
P.compute_reflection(_v, _p, pbpair.second);
109114
it++;
110115
}
@@ -120,7 +125,7 @@ struct Walk
120125
<
121126
typename GenericPolytope
122127
>
123-
inline void get_starting_point(GenericPolytope& P,
128+
inline void get_starting_point(GenericPolytope const& P,
124129
Point const& center,
125130
Point &q,
126131
unsigned int const& walk_length,
@@ -141,7 +146,7 @@ struct Walk
141146
<
142147
typename GenericPolytope
143148
>
144-
inline void parameters_burnin(GenericPolytope& P,
149+
inline void parameters_burnin(GenericPolytope const& P,
145150
Point const& center,
146151
unsigned int const& num_points,
147152
unsigned int const& walk_length,
@@ -191,7 +196,7 @@ private :
191196
<
192197
typename GenericPolytope
193198
>
194-
inline void initialize(GenericPolytope& P,
199+
inline void initialize(GenericPolytope const& P,
195200
Point const& p,
196201
NT const& a_i,
197202
RandomNumberGenerator &rng)
@@ -204,6 +209,9 @@ private :
204209
NT T = rng.sample_urdist() * _Len;
205210
int it = 0;
206211

212+
GenericPolytope P_normalized = P;
213+
P_normalized.normalize();
214+
207215
while (it <= _rho)
208216
{
209217
auto pbpair
@@ -218,12 +226,50 @@ private :
218226
}
219227
_lambda_prev = pbpair.first;
220228
update_position(_p, _v, _lambda_prev, _omega);
229+
nudge_in(P_normalized, _p);
221230
T -= _lambda_prev;
222231
P.compute_reflection(_v, _p, pbpair.second);
223232
it++;
224233
}
225234
}
226235

236+
template
237+
<
238+
typename GenericPolytope
239+
>
240+
inline void nudge_in(GenericPolytope& P, Point& p, NT tol=NT(0))
241+
{
242+
MT A = P.get_mat();
243+
VT b = P.get_vec();
244+
int m = A.rows();
245+
246+
VT b_Ax = b - A * p.getCoefficients();
247+
const NT* b_Ax_data = b_Ax.data();
248+
249+
NT dist;
250+
251+
for (int i = 0; i < m; i++) {
252+
253+
dist = *b_Ax_data;
254+
255+
if (dist < NT(-tol)){
256+
//Nudging correction
257+
NT eps = -1e-7;
258+
259+
NT eps_1 = -dist;
260+
//A.row is already normalized, no need to do it again
261+
VT A_i = A.row(i);
262+
NT eps_2 = eps_1 + eps;
263+
264+
//Nudge the point inside with respect to the normal its vector
265+
Point shift(A_i);
266+
shift.operator*=(eps_2);
267+
p.operator+=(shift);
268+
}
269+
b_Ax_data++;
270+
}
271+
}
272+
227273
inline void update_position(Point &p, Point &v, NT const& T, NT const& omega)
228274
{
229275
NT next_p, next_v;
@@ -274,4 +320,3 @@ private :
274320

275321

276322
#endif // RANDOM_WALKS_GAUSSIAN_HMC_WALK_HPP
277-

0 commit comments

Comments
 (0)