Skip to content

Commit f924e17

Browse files
Merge pull request #58 from numericalEFT/zyli
Add Parameter `thermal_ratio` to adjust the thermalization steps for the MCMC method.
2 parents e561c7c + 66e5025 commit f924e17

File tree

2 files changed

+11
-7
lines changed

2 files changed

+11
-7
lines changed

src/main.jl

+7-4
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
block=16,
88
measure::Union{Nothing,Function}=nothing,
99
measurefreq::Int=1,
10+
nburnin::Int=100,
1011
inplace::Bool=false,
1112
adapt=true,
1213
gamma=1.0,
@@ -35,6 +36,7 @@ Calculate the integrals, collect statistics, and return a `Result` struct contai
3536
- For `solver = :vegas` or `:vegasmc`, the function signature should be `measure(var, obs, relative_weights, config)`. Here, `obs` is a vector of observable values for each component of the integrand and `relative_weights` are the weights calculated from the integrand multiplied by the probability of the corresponding variables.
3637
- For `solver = :mcmc`, the signature should be `measure(idx, var, obs, relative_weight, config)`, where `obs` is the observable vector and `relative_weight` is the weight calculated from the `idx`-th integrand multiplied by the probability of the variables.
3738
- `measurefreq`: How often the measurement function is called (default: `1`).
39+
- `nburnin` : Tha thermalization steps for MCMC method
3840
- `inplace`: Whether to use the inplace version of the integrand. Default is `false`, which is more convenient for integrand with a few return values but may cause type instability. Only useful for the :vegas and :vegasmc solver.
3941
- `adapt`: Whether to adapt the grid and the reweight factor (default: `true`).
4042
- `gamma`: Learning rate of the reweight factor after each iteration (default: `1.0`).
@@ -80,6 +82,7 @@ function integrate(integrand::Function;
8082
ignore::Int=adapt ? 1 : 0, #ignore the first `ignore` iterations in average
8183
measure::Union{Nothing,Function}=nothing,
8284
measurefreq::Int=1,
85+
nburnin::Int = 100,
8386
inplace::Bool=false, # whether to use the inplace version of the integrand
8487
parallel::Symbol=:nothread, # :thread or :nothread
8588
print=-1, printio=stdout, timer=[],
@@ -151,13 +154,13 @@ function integrate(integrand::Function;
151154
Threads.@threads for _ in 1:block/MCUtility.mpi_nprocs()
152155
_block!(configs, obsSum, obsSquaredSum, summedConfig, solver, progress,
153156
integrand, nevalperblock, print, timer, debug,
154-
measure, measurefreq, inplace, parallel)
157+
measure, measurefreq, nburnin, inplace, parallel)
155158
end
156159
else
157160
for _ in 1:block/MCUtility.mpi_nprocs()
158161
_block!(configs, obsSum, obsSquaredSum, summedConfig, solver, progress,
159162
integrand, nevalperblock, print, timer, debug,
160-
measure, measurefreq, inplace, parallel)
163+
measure, measurefreq, nburnin, inplace, parallel)
161164
end
162165
end
163166
end
@@ -233,7 +236,7 @@ end
233236
function _block!(configs, obsSum, obsSquaredSum, summedConfig,
234237
solver, progress,
235238
integrand::Function, nevalperblock, print, timer, debug::Bool,
236-
measure::Union{Nothing,Function}, measurefreq, inplace, parallel)
239+
measure::Union{Nothing,Function}, measurefreq, nburnin, inplace, parallel)
237240

238241
rank = MCUtility.threadid(parallel)
239242
# println(rank)
@@ -249,7 +252,7 @@ function _block!(configs, obsSum, obsSquaredSum, summedConfig,
249252
measure=measure, measurefreq=measurefreq, inplace=inplace)
250253
elseif solver == :mcmc
251254
MCMC.montecarlo(config_n, integrand, nevalperblock, print, timer, debug;
252-
measure=measure, measurefreq=measurefreq)
255+
measure=measure, measurefreq=measurefreq, nburnin = nburnin)
253256
else
254257
error("Solver $solver is not supported!")
255258
end

src/mcmc/montecarlo.jl

+4-3
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,8 @@ function montecarlo(config::Configuration{N,V,P,O,T}, integrand::Function, neval
7373
verbose=0, timer=[], debug=false;
7474
measurefreq::Int=1,
7575
measure::Union{Nothing,Function}=nothing,
76-
idx::Int=1 # the integral to start with
76+
idx::Int=1, # the integral to start with
77+
nburnin::Int=100
7778
) where {N,V,P,O,T}
7879

7980
@assert measurefreq > 0
@@ -130,7 +131,7 @@ function montecarlo(config::Configuration{N,V,P,O,T}, integrand::Function, neval
130131
# end
131132
startTime = time()
132133

133-
for i = 1:neval
134+
for i = 1:(neval+nburnin)
134135
# config.neval += 1
135136
config.visited[state.curr] += 1
136137
_update = rand(config.rng, updates) # randomly select an update
@@ -140,7 +141,7 @@ function montecarlo(config::Configuration{N,V,P,O,T}, integrand::Function, neval
140141
if debug && (isfinite(state.probability) == false)
141142
@warn("integrand probability = $(state.probability) is not finite at step $(config.neval)")
142143
end
143-
if i % measurefreq == 0 && i >= neval / 100
144+
if i % measurefreq == 0 && i >= nburnin
144145

145146
######## accumulate variable #################
146147
if state.curr != config.norm

0 commit comments

Comments
 (0)