@@ -72,7 +72,8 @@ <h1>Source code for integrators</h1><div class="highlight"><pre>
72
72
< span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> dtype</ span > < span class ="o "> =</ span > < span class ="n "> dtype</ span >
73
73
< span class ="k "> if</ span > < span class ="n "> maps</ span > < span class ="p "> :</ span >
74
74
< span class ="k "> if</ span > < span class ="ow "> not</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> dtype</ span > < span class ="o "> ==</ span > < span class ="n "> maps</ span > < span class ="o "> .</ span > < span class ="n "> dtype</ span > < span class ="p "> :</ span >
75
- < span class ="k "> raise</ span > < span class ="ne "> ValueError</ span > < span class ="p "> (</ span > < span class ="s2 "> "Float type of maps should be same as integrator."</ span > < span class ="p "> )</ span >
75
+ < span class ="k "> raise</ span > < span class ="ne "> ValueError</ span > < span class ="p "> (</ span >
76
+ < span class ="s2 "> "Float type of maps should be same as integrator."</ span > < span class ="p "> )</ span >
76
77
< span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> bounds</ span > < span class ="o "> =</ span > < span class ="n "> maps</ span > < span class ="o "> .</ span > < span class ="n "> bounds</ span >
77
78
< span class ="k "> else</ span > < span class ="p "> :</ span >
78
79
< span class ="k "> if</ span > < span class ="ow "> not</ span > < span class ="nb "> isinstance</ span > < span class ="p "> (</ span > < span class ="n "> bounds</ span > < span class ="p "> ,</ span > < span class ="p "> (</ span > < span class ="nb "> list</ span > < span class ="p "> ,</ span > < span class ="n "> np</ span > < span class ="o "> .</ span > < span class ="n "> ndarray</ span > < span class ="p "> )):</ span >
@@ -123,6 +124,7 @@ <h1>Source code for integrators</h1><div class="highlight"><pre>
123
124
< span class ="n "> device</ span > < span class ="o "> =</ span > < span class ="s2 "> "cpu"</ span > < span class ="p "> ,</ span >
124
125
< span class ="n "> dtype</ span > < span class ="o "> =</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> float64</ span > < span class ="p "> ,</ span >
125
126
< span class ="p "> ):</ span >
127
+
126
128
< span class ="nb "> super</ span > < span class ="p "> ()</ span > < span class ="o "> .</ span > < span class ="fm "> __init__</ span > < span class ="p "> (</ span > < span class ="n "> maps</ span > < span class ="p "> ,</ span > < span class ="n "> bounds</ span > < span class ="p "> ,</ span > < span class ="n "> q0</ span > < span class ="p "> ,</ span > < span class ="n "> neval</ span > < span class ="p "> ,</ span > < span class ="n "> nbatch</ span > < span class ="p "> ,</ span > < span class ="n "> device</ span > < span class ="p "> ,</ span > < span class ="n "> dtype</ span > < span class ="p "> )</ span >
127
129
128
130
< span class ="k "> def</ span > < span class ="fm "> __call__</ span > < span class ="p "> (</ span > < span class ="bp "> self</ span > < span class ="p "> ,</ span > < span class ="n "> f</ span > < span class ="p "> :</ span > < span class ="n "> Callable</ span > < span class ="p "> ,</ span > < span class ="o "> **</ span > < span class ="n "> kwargs</ span > < span class ="p "> ):</ span >
@@ -142,12 +144,14 @@ <h1>Source code for integrators</h1><div class="highlight"><pre>
142
144
< span class ="p "> )</ span >
143
145
144
146
< span class ="n "> epoch</ span > < span class ="o "> =</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> neval</ span > < span class ="o "> //</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> nbatch</ span >
145
- < span class ="n "> values</ span > < span class ="o "> =</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> zeros</ span > < span class ="p "> ((</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> nbatch</ span > < span class ="p "> ,</ span > < span class ="n "> f_size</ span > < span class ="p "> ),</ span > < span class ="n "> dtype</ span > < span class ="o "> =</ span > < span class ="n "> type_fval</ span > < span class ="p "> ,</ span > < span class ="n "> device</ span > < span class ="o "> =</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> device</ span > < span class ="p "> )</ span >
147
+ < span class ="n "> values</ span > < span class ="o "> =</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> zeros</ span > < span class ="p "> ((</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> nbatch</ span > < span class ="p "> ,</ span > < span class ="n "> f_size</ span > < span class ="p "> ),</ span >
148
+ < span class ="n "> dtype</ span > < span class ="o "> =</ span > < span class ="n "> type_fval</ span > < span class ="p "> ,</ span > < span class ="n "> device</ span > < span class ="o "> =</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> device</ span > < span class ="p "> )</ span >
146
149
147
150
< span class ="k "> for</ span > < span class ="n "> iepoch</ span > < span class ="ow "> in</ span > < span class ="nb "> range</ span > < span class ="p "> (</ span > < span class ="n "> epoch</ span > < span class ="p "> ):</ span >
148
151
< span class ="n "> x</ span > < span class ="p "> ,</ span > < span class ="n "> log_detJ</ span > < span class ="o "> =</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> sample</ span > < span class ="p "> (</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> nbatch</ span > < span class ="p "> )</ span >
149
152
< span class ="n "> f_values</ span > < span class ="o "> =</ span > < span class ="n "> f</ span > < span class ="p "> (</ span > < span class ="n "> x</ span > < span class ="p "> )</ span >
150
- < span class ="n "> batch_results</ span > < span class ="o "> =</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> _multiply_by_jacobian</ span > < span class ="p "> (</ span > < span class ="n "> f_values</ span > < span class ="p "> ,</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> exp</ span > < span class ="p "> (</ span > < span class ="n "> log_detJ</ span > < span class ="p "> ))</ span >
153
+ < span class ="n "> batch_results</ span > < span class ="o "> =</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> _multiply_by_jacobian</ span > < span class ="p "> (</ span >
154
+ < span class ="n "> f_values</ span > < span class ="p "> ,</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> exp</ span > < span class ="p "> (</ span > < span class ="n "> log_detJ</ span > < span class ="p "> ))</ span >
151
155
152
156
< span class ="n "> values</ span > < span class ="o "> +=</ span > < span class ="n "> batch_results</ span > < span class ="o "> /</ span > < span class ="n "> epoch</ span >
153
157
@@ -177,7 +181,8 @@ <h1>Source code for integrators</h1><div class="highlight"><pre>
177
181
< span class ="n "> rangebounds</ span > < span class ="o "> =</ span > < span class ="n "> bounds</ span > < span class ="p "> [:,</ span > < span class ="mi "> 1</ span > < span class ="p "> ]</ span > < span class ="o "> -</ span > < span class ="n "> bounds</ span > < span class ="p "> [:,</ span > < span class ="mi "> 0</ span > < span class ="p "> ]</ span >
178
182
< span class ="n "> step_size</ span > < span class ="o "> =</ span > < span class ="n "> kwargs</ span > < span class ="o "> .</ span > < span class ="n "> get</ span > < span class ="p "> (</ span > < span class ="s2 "> "step_size"</ span > < span class ="p "> ,</ span > < span class ="mf "> 0.2</ span > < span class ="p "> )</ span >
179
183
< span class ="n "> step_sizes</ span > < span class ="o "> =</ span > < span class ="n "> rangebounds</ span > < span class ="o "> *</ span > < span class ="n "> step_size</ span >
180
- < span class ="n "> step</ span > < span class ="o "> =</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> empty</ span > < span class ="p "> (</ span > < span class ="n "> dim</ span > < span class ="p "> ,</ span > < span class ="n "> device</ span > < span class ="o "> =</ span > < span class ="n "> device</ span > < span class ="p "> ,</ span > < span class ="n "> dtype</ span > < span class ="o "> =</ span > < span class ="n "> dtype</ span > < span class ="p "> )</ span > < span class ="o "> .</ span > < span class ="n "> uniform_</ span > < span class ="p "> (</ span > < span class ="o "> -</ span > < span class ="mi "> 1</ span > < span class ="p "> ,</ span > < span class ="mi "> 1</ span > < span class ="p "> )</ span > < span class ="o "> *</ span > < span class ="n "> step_sizes</ span >
184
+ < span class ="n "> step</ span > < span class ="o "> =</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> empty</ span > < span class ="p "> (</ span > < span class ="n "> dim</ span > < span class ="p "> ,</ span > < span class ="n "> device</ span > < span class ="o "> =</ span > < span class ="n "> device</ span > < span class ="p "> ,</ span >
185
+ < span class ="n "> dtype</ span > < span class ="o "> =</ span > < span class ="n "> dtype</ span > < span class ="p "> )</ span > < span class ="o "> .</ span > < span class ="n "> uniform_</ span > < span class ="p "> (</ span > < span class ="o "> -</ span > < span class ="mi "> 1</ span > < span class ="p "> ,</ span > < span class ="mi "> 1</ span > < span class ="p "> )</ span > < span class ="o "> *</ span > < span class ="n "> step_sizes</ span >
181
186
< span class ="n "> new_u</ span > < span class ="o "> =</ span > < span class ="p "> (</ span > < span class ="n "> u</ span > < span class ="o "> +</ span > < span class ="n "> step</ span > < span class ="o "> -</ span > < span class ="n "> bounds</ span > < span class ="p "> [:,</ span > < span class ="mi "> 0</ span > < span class ="p "> ])</ span > < span class ="o "> %</ span > < span class ="n "> rangebounds</ span > < span class ="o "> +</ span > < span class ="n "> bounds</ span > < span class ="p "> [:,</ span > < span class ="mi "> 0</ span > < span class ="p "> ]</ span >
182
187
< span class ="k "> return</ span > < span class ="n "> new_u</ span > </ div >
183
188
@@ -254,7 +259,8 @@ <h1>Source code for integrators</h1><div class="highlight"><pre>
254
259
< span class ="p "> )</ span >
255
260
< span class ="n "> type_fval</ span > < span class ="o "> =</ span > < span class ="n "> current_fval</ span > < span class ="o "> .</ span > < span class ="n "> dtype</ span >
256
261
257
- < span class ="n "> current_weight</ span > < span class ="o "> =</ span > < span class ="n "> mix_rate</ span > < span class ="o "> /</ span > < span class ="n "> current_jac</ span > < span class ="o "> +</ span > < span class ="p "> (</ span > < span class ="mi "> 1</ span > < span class ="o "> -</ span > < span class ="n "> mix_rate</ span > < span class ="p "> )</ span > < span class ="o "> *</ span > < span class ="n "> current_fval</ span > < span class ="o "> .</ span > < span class ="n "> abs</ span > < span class ="p "> ()</ span >
262
+ < span class ="n "> current_weight</ span > < span class ="o "> =</ span > < span class ="n "> mix_rate</ span > < span class ="o "> /</ span > < span class ="n "> current_jac</ span > < span class ="o "> +</ span > \
263
+ < span class ="p "> (</ span > < span class ="mi "> 1</ span > < span class ="o "> -</ span > < span class ="n "> mix_rate</ span > < span class ="p "> )</ span > < span class ="o "> *</ span > < span class ="n "> current_fval</ span > < span class ="o "> .</ span > < span class ="n "> abs</ span > < span class ="p "> ()</ span >
258
264
< span class ="n "> current_weight</ span > < span class ="o "> .</ span > < span class ="n "> masked_fill_</ span > < span class ="p "> (</ span > < span class ="n "> current_weight</ span > < span class ="o "> <</ span > < span class ="n "> epsilon</ span > < span class ="p "> ,</ span > < span class ="n "> epsilon</ span > < span class ="p "> )</ span >
259
265
260
266
< span class ="n "> n_meas</ span > < span class ="o "> =</ span > < span class ="n "> epoch</ span > < span class ="o "> //</ span > < span class ="n "> thinning</ span >
@@ -289,8 +295,10 @@ <h1>Source code for integrators</h1><div class="highlight"><pre>
289
295
< span class ="n "> current_y</ span > < span class ="p "> ,</ span > < span class ="n "> current_x</ span > < span class ="p "> ,</ span > < span class ="n "> current_weight</ span > < span class ="p "> ,</ span > < span class ="n "> current_jac</ span >
290
296
< span class ="p "> )</ span >
291
297
292
- < span class ="n "> values</ span > < span class ="o "> =</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> zeros</ span > < span class ="p "> ((</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> nbatch</ span > < span class ="p "> ,</ span > < span class ="n "> f_size</ span > < span class ="p "> ),</ span > < span class ="n "> dtype</ span > < span class ="o "> =</ span > < span class ="n "> type_fval</ span > < span class ="p "> ,</ span > < span class ="n "> device</ span > < span class ="o "> =</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> device</ span > < span class ="p "> )</ span >
293
- < span class ="n "> refvalues</ span > < span class ="o "> =</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> zeros</ span > < span class ="p "> (</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> nbatch</ span > < span class ="p "> ,</ span > < span class ="n "> dtype</ span > < span class ="o "> =</ span > < span class ="n "> type_fval</ span > < span class ="p "> ,</ span > < span class ="n "> device</ span > < span class ="o "> =</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> device</ span > < span class ="p "> )</ span >
298
+ < span class ="n "> values</ span > < span class ="o "> =</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> zeros</ span > < span class ="p "> ((</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> nbatch</ span > < span class ="p "> ,</ span > < span class ="n "> f_size</ span > < span class ="p "> ),</ span >
299
+ < span class ="n "> dtype</ span > < span class ="o "> =</ span > < span class ="n "> type_fval</ span > < span class ="p "> ,</ span > < span class ="n "> device</ span > < span class ="o "> =</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> device</ span > < span class ="p "> )</ span >
300
+ < span class ="n "> refvalues</ span > < span class ="o "> =</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> zeros</ span > < span class ="p "> (</ span >
301
+ < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> nbatch</ span > < span class ="p "> ,</ span > < span class ="n "> dtype</ span > < span class ="o "> =</ span > < span class ="n "> type_fval</ span > < span class ="p "> ,</ span > < span class ="n "> device</ span > < span class ="o "> =</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> device</ span > < span class ="p "> )</ span >
294
302
295
303
< span class ="k "> for</ span > < span class ="n "> imeas</ span > < span class ="ow "> in</ span > < span class ="nb "> range</ span > < span class ="p "> (</ span > < span class ="n "> n_meas</ span > < span class ="p "> ):</ span >
296
304
< span class ="k "> for</ span > < span class ="n "> j</ span > < span class ="ow "> in</ span > < span class ="nb "> range</ span > < span class ="p "> (</ span > < span class ="n "> thinning</ span > < span class ="p "> ):</ span >
0 commit comments