Skip to content

Commit 1324ceb

Browse files
committed
deploy: 39debfc
1 parent fbc189e commit 1324ceb

File tree

3 files changed

+45
-25
lines changed

3 files changed

+45
-25
lines changed

_modules/integrators.html

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,8 @@ <h1>Source code for integrators</h1><div class="highlight"><pre>
7272
<span class="bp">self</span><span class="o">.</span><span class="n">dtype</span> <span class="o">=</span> <span class="n">dtype</span>
7373
<span class="k">if</span> <span class="n">maps</span><span class="p">:</span>
7474
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">maps</span><span class="o">.</span><span class="n">dtype</span><span class="p">:</span>
75-
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Float type of maps should be same as integrator.&quot;</span><span class="p">)</span>
75+
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
76+
<span class="s2">&quot;Float type of maps should be same as integrator.&quot;</span><span class="p">)</span>
7677
<span class="bp">self</span><span class="o">.</span><span class="n">bounds</span> <span class="o">=</span> <span class="n">maps</span><span class="o">.</span><span class="n">bounds</span>
7778
<span class="k">else</span><span class="p">:</span>
7879
<span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">bounds</span><span class="p">,</span> <span class="p">(</span><span class="nb">list</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">)):</span>
@@ -123,6 +124,7 @@ <h1>Source code for integrators</h1><div class="highlight"><pre>
123124
<span class="n">device</span><span class="o">=</span><span class="s2">&quot;cpu&quot;</span><span class="p">,</span>
124125
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float64</span><span class="p">,</span>
125126
<span class="p">):</span>
127+
126128
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">maps</span><span class="p">,</span> <span class="n">bounds</span><span class="p">,</span> <span class="n">q0</span><span class="p">,</span> <span class="n">neval</span><span class="p">,</span> <span class="n">nbatch</span><span class="p">,</span> <span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="p">)</span>
127129

128130
<span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">f</span><span class="p">:</span> <span class="n">Callable</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
@@ -142,12 +144,14 @@ <h1>Source code for integrators</h1><div class="highlight"><pre>
142144
<span class="p">)</span>
143145

144146
<span class="n">epoch</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">neval</span> <span class="o">//</span> <span class="bp">self</span><span class="o">.</span><span class="n">nbatch</span>
145-
<span class="n">values</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">nbatch</span><span class="p">,</span> <span class="n">f_size</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">type_fval</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
147+
<span class="n">values</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">nbatch</span><span class="p">,</span> <span class="n">f_size</span><span class="p">),</span>
148+
<span class="n">dtype</span><span class="o">=</span><span class="n">type_fval</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
146149

147150
<span class="k">for</span> <span class="n">iepoch</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">epoch</span><span class="p">):</span>
148151
<span class="n">x</span><span class="p">,</span> <span class="n">log_detJ</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">sample</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">nbatch</span><span class="p">)</span>
149152
<span class="n">f_values</span> <span class="o">=</span> <span class="n">f</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
150-
<span class="n">batch_results</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_multiply_by_jacobian</span><span class="p">(</span><span class="n">f_values</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">log_detJ</span><span class="p">))</span>
153+
<span class="n">batch_results</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_multiply_by_jacobian</span><span class="p">(</span>
154+
<span class="n">f_values</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">log_detJ</span><span class="p">))</span>
151155

152156
<span class="n">values</span> <span class="o">+=</span> <span class="n">batch_results</span> <span class="o">/</span> <span class="n">epoch</span>
153157

@@ -177,7 +181,8 @@ <h1>Source code for integrators</h1><div class="highlight"><pre>
177181
<span class="n">rangebounds</span> <span class="o">=</span> <span class="n">bounds</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">]</span> <span class="o">-</span> <span class="n">bounds</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span>
178182
<span class="n">step_size</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;step_size&quot;</span><span class="p">,</span> <span class="mf">0.2</span><span class="p">)</span>
179183
<span class="n">step_sizes</span> <span class="o">=</span> <span class="n">rangebounds</span> <span class="o">*</span> <span class="n">step_size</span>
180-
<span class="n">step</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span><span class="o">.</span><span class="n">uniform_</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">step_sizes</span>
184+
<span class="n">step</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span>
185+
<span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span><span class="o">.</span><span class="n">uniform_</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">step_sizes</span>
181186
<span class="n">new_u</span> <span class="o">=</span> <span class="p">(</span><span class="n">u</span> <span class="o">+</span> <span class="n">step</span> <span class="o">-</span> <span class="n">bounds</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">])</span> <span class="o">%</span> <span class="n">rangebounds</span> <span class="o">+</span> <span class="n">bounds</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span>
182187
<span class="k">return</span> <span class="n">new_u</span></div>
183188

@@ -254,7 +259,8 @@ <h1>Source code for integrators</h1><div class="highlight"><pre>
254259
<span class="p">)</span>
255260
<span class="n">type_fval</span> <span class="o">=</span> <span class="n">current_fval</span><span class="o">.</span><span class="n">dtype</span>
256261

257-
<span class="n">current_weight</span> <span class="o">=</span> <span class="n">mix_rate</span> <span class="o">/</span> <span class="n">current_jac</span> <span class="o">+</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">mix_rate</span><span class="p">)</span> <span class="o">*</span> <span class="n">current_fval</span><span class="o">.</span><span class="n">abs</span><span class="p">()</span>
262+
<span class="n">current_weight</span> <span class="o">=</span> <span class="n">mix_rate</span> <span class="o">/</span> <span class="n">current_jac</span> <span class="o">+</span> \
263+
<span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">mix_rate</span><span class="p">)</span> <span class="o">*</span> <span class="n">current_fval</span><span class="o">.</span><span class="n">abs</span><span class="p">()</span>
258264
<span class="n">current_weight</span><span class="o">.</span><span class="n">masked_fill_</span><span class="p">(</span><span class="n">current_weight</span> <span class="o">&lt;</span> <span class="n">epsilon</span><span class="p">,</span> <span class="n">epsilon</span><span class="p">)</span>
259265

260266
<span class="n">n_meas</span> <span class="o">=</span> <span class="n">epoch</span> <span class="o">//</span> <span class="n">thinning</span>
@@ -289,8 +295,10 @@ <h1>Source code for integrators</h1><div class="highlight"><pre>
289295
<span class="n">current_y</span><span class="p">,</span> <span class="n">current_x</span><span class="p">,</span> <span class="n">current_weight</span><span class="p">,</span> <span class="n">current_jac</span>
290296
<span class="p">)</span>
291297

292-
<span class="n">values</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">nbatch</span><span class="p">,</span> <span class="n">f_size</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">type_fval</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
293-
<span class="n">refvalues</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">nbatch</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">type_fval</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
298+
<span class="n">values</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">nbatch</span><span class="p">,</span> <span class="n">f_size</span><span class="p">),</span>
299+
<span class="n">dtype</span><span class="o">=</span><span class="n">type_fval</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
300+
<span class="n">refvalues</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span>
301+
<span class="bp">self</span><span class="o">.</span><span class="n">nbatch</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">type_fval</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
294302

295303
<span class="k">for</span> <span class="n">imeas</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_meas</span><span class="p">):</span>
296304
<span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">thinning</span><span class="p">):</span>

0 commit comments

Comments
 (0)