Skip to content

Commit 40888f1

Browse files
committed
Gaussian HMC Walk Polytope Normalization
1 parent 8a7818e commit 40888f1

File tree

2 files changed

+24
-17
lines changed

2 files changed

+24
-17
lines changed

include/convex_bodies/hpolytope.h

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ class HPolytope {
5656
MT A; //matrix A
5757
VT b; // vector b, s.t.: Ax<=b
5858
std::pair<Point, NT> _inner_ball;
59+
bool normalized = 0;
5960

6061
public:
6162
//TODO: the default implementation of the Big3 should be ok. Recheck.
@@ -883,15 +884,24 @@ class HPolytope {
883884

884885
void normalize()
885886
{
886-
NT row_norm;
887-
for (int i = 0; i < num_of_hyperplanes(); ++i)
887+
if(!is_normalized())
888888
{
889-
row_norm = A.row(i).norm();
890-
A.row(i) = A.row(i) / row_norm;
891-
b(i) = b(i) / row_norm;
889+
NT row_norm;
890+
for (int i = 0; i < num_of_hyperplanes(); ++i)
891+
{
892+
row_norm = A.row(i).norm();
893+
A.row(i) = A.row(i) / row_norm;
894+
b(i) = b(i) / row_norm;
895+
}
896+
normalized = true;
892897
}
893898
}
894899

900+
bool is_normalized()
901+
{
902+
return normalized;
903+
}
904+
895905
void compute_reflection(Point& v, Point const&, int const& facet) const
896906
{
897907
v += -2 * v.dot(A.row(facet)) * A.row(facet);

include/random_walks/gaussian_hamiltonian_monte_carlo_exact_walk.hpp

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,9 @@ struct Walk
5555
typedef typename Polytope::VT VT;
5656

5757
template <typename GenericPolytope>
58-
Walk(GenericPolytope &P, Point const& p, NT const& a_i, RandomNumberGenerator &rng)
58+
Walk(GenericPolytope& P, Point const& p, NT const& a_i, RandomNumberGenerator &rng)
5959
{
60+
P.normalize();
6061
_Len = compute_diameter<GenericPolytope>
6162
::template compute<NT>(P);
6263
_omega = std::sqrt(NT(2) * a_i);
@@ -65,9 +66,10 @@ struct Walk
6566
}
6667

6768
template <typename GenericPolytope>
68-
Walk(GenericPolytope &P, Point const& p, NT const& a_i, RandomNumberGenerator &rng,
69+
Walk(GenericPolytope& P, Point const& p, NT const& a_i, RandomNumberGenerator &rng,
6970
parameters const& params)
7071
{
72+
P.normalize();
7173
_Len = params.set_L ? params.m_L
7274
: compute_diameter<GenericPolytope>
7375
::template compute<NT>(P);
@@ -80,18 +82,16 @@ struct Walk
8082
<
8183
typename GenericPolytope
8284
>
83-
inline void apply(GenericPolytope& P,
85+
inline void apply(GenericPolytope const& P,
8486
Point& p,
8587
NT const& a_i,
8688
unsigned int const& walk_length,
8789
RandomNumberGenerator &rng)
8890
{
91+
8992
unsigned int n = P.dimension();
9093
NT T;
9194

92-
//normalize the Polyope
93-
P.normalize();
94-
9595
for (auto j=0u; j<walk_length; ++j)
9696
{
9797
T = rng.sample_urdist() * _Len;
@@ -124,7 +124,7 @@ struct Walk
124124
<
125125
typename GenericPolytope
126126
>
127-
inline void get_starting_point(GenericPolytope& P,
127+
inline void get_starting_point(GenericPolytope const& P,
128128
Point const& center,
129129
Point &q,
130130
unsigned int const& walk_length,
@@ -145,7 +145,7 @@ struct Walk
145145
<
146146
typename GenericPolytope
147147
>
148-
inline void parameters_burnin(GenericPolytope& P,
148+
inline void parameters_burnin(GenericPolytope const& P,
149149
Point const& center,
150150
unsigned int const& num_points,
151151
unsigned int const& walk_length,
@@ -195,7 +195,7 @@ private :
195195
<
196196
typename GenericPolytope
197197
>
198-
inline void initialize(GenericPolytope& P,
198+
inline void initialize(GenericPolytope const& P,
199199
Point const& p,
200200
NT const& a_i,
201201
RandomNumberGenerator &rng)
@@ -205,9 +205,6 @@ private :
205205
_p = p;
206206
_v = GetDirection<Point>::apply(n, rng, false);
207207

208-
//normalize the Polyope
209-
P.normalize();
210-
211208
NT T = rng.sample_urdist() * _Len;
212209
int it = 0;
213210

0 commit comments

Comments
 (0)