Commit a18c12f

torfjelde, github-actions[bot], and devmotion authored
Addition of step_warmup (#117)
* added step_warmup which can be overloaded when convenient
* added step_warmup to docs
* Update src/interface.jl
  Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
* introduce new kwarg `num_warmup` to `sample` which uses `step_warmup`
* updated docs
* allow combination of discard_initial and num_warmup
* added docstring for mcmcsample
* Apply suggestions from code review
  Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
* Apply suggestions from code review
  Co-authored-by: David Widmann <[email protected]>
* Update src/sample.jl
  Co-authored-by: David Widmann <[email protected]>
* removed docstring and deferred description of keyword arguments to the docs
* Update src/sample.jl
  Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
* Update src/sample.jl
  Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
* added num_warmup to common keyword arguments docs
* also allow step_warmup for the initial step
* simplify logic for discarding initial samples
* Apply suggestions from code review
  Co-authored-by: David Widmann <[email protected]>
* also report progress for the discarded samples
* Apply suggestions from code review
  Co-authored-by: David Widmann <[email protected]>
* move progress-report to end of for-loop for discard samples
* move step_warmup to the inner while loops too
* Update src/sample.jl
  Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
* Apply suggestions from code review
  Co-authored-by: David Widmann <[email protected]>
* reverted to for-loop
* Update src/sample.jl
  Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
* added accidentally removed comment
* Update src/sample.jl
  Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
* fixed formatting
* fix typo
* Apply suggestions from code review
  Co-authored-by: David Widmann <[email protected]>
* Added testing of warmup steps
* Added checks as @devmotion requested
* Removed unintended change in previous commit
* Bumped patch version
* Bump minor version instead of patch version since this is a new feature

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: David Widmann <[email protected]>
1 parent 8431b31 commit a18c12f

7 files changed, +193 -34 lines changed

Project.toml

+1 -1
@@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
 keywords = ["markov chain monte carlo", "probabilistic programming"]
 license = "MIT"
 desc = "A lightweight interface for common MCMC methods."
-version = "5.3.0"
+version = "5.4.0"
 
 [deps]
 BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"

docs/src/api.md

+7 -3
@@ -71,9 +71,13 @@ Common keyword arguments for regular and parallel sampling are:
 - `progress` (default: `AbstractMCMC.PROGRESS[]` which is `true` initially): toggles progress logging
 - `chain_type` (default: `Any`): determines the type of the returned chain
 - `callback` (default: `nothing`): if `callback !== nothing`, then
-  `callback(rng, model, sampler, sample, state, iteration)` is called after every sampling step,
-  where `sample` is the most recent sample of the Markov chain and `state` and `iteration` are the current state and iteration of the sampler
-- `discard_initial` (default: `0`): number of initial samples that are discarded
+  `callback(rng, model, sampler, sample, iteration)` is called after every sampling step,
+  where `sample` is the most recent sample of the Markov chain and `iteration` is the current iteration
+- `num_warmup` (default: `0`): number of "warm-up" steps to take before the first "regular" step,
+  i.e. number of times to call [`AbstractMCMC.step_warmup`](@ref) before the first call to
+  [`AbstractMCMC.step`](@ref).
+- `discard_initial` (default: `num_warmup`): number of initial samples that are discarded. Note that
+  if `discard_initial < num_warmup`, warm-up samples will also be included in the resulting samples.
 - `thinning` (default: `1`): factor by which to thin samples.
 - `initial_state` (default: `nothing`): if `initial_state !== nothing`, the first call to [`AbstractMCMC.step`](@ref)
   is passed `initial_state` as the `state` argument.
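
The interaction between `num_warmup` and `discard_initial` documented in this hunk can be sketched with a few lines of arithmetic. The helper `split_warmup` below is hypothetical and not part of AbstractMCMC; it only assumes the `min(num_warmup, discard_initial)` bookkeeping that this commit adds to `src/sample.jl`.

```julia
# Hypothetical helper, not part of AbstractMCMC: how many warm-up samples are
# discarded and how many end up in the returned chain.
function split_warmup(num_warmup::Int, discard_initial::Int=num_warmup)
    num_warmup >= 0 || throw(ArgumentError("num_warmup must be non-negative"))
    discard_initial >= 0 ||
        throw(ArgumentError("discard_initial must be non-negative"))
    discarded = min(num_warmup, discard_initial)
    return (discarded=discarded, kept=num_warmup - discarded)
end

split_warmup(50)      # default: every warm-up sample is discarded
split_warmup(50, 10)  # only 10 are discarded, so 40 warm-up samples are kept
```

With the default `discard_initial = num_warmup` nothing from the warm-up phase is kept; lowering `discard_initial` is the way to inspect warm-up samples in the result.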

docs/src/design.md

+9
@@ -63,6 +63,15 @@ the sampling step of the inference method.
 AbstractMCMC.step
 ```
 
+If one also has some special handling of the warmup-stage of sampling, then this can be specified by overloading
+
+```@docs
+AbstractMCMC.step_warmup
+```
+
+which will be used for the first `num_warmup` iterations, as specified as a keyword argument to [`AbstractMCMC.sample`](@ref).
+Note that this is optional; by default it simply calls [`AbstractMCMC.step`](@ref) from above.
+
 ## Collecting samples
 
 !!! note

src/interface.jl

+17
@@ -73,6 +73,23 @@ current `state` of the sampler.
 """
 function step end
 
+"""
+    step_warmup(rng, model, sampler[, state; kwargs...])
+
+Return a 2-tuple of the next sample and the next state of the MCMC `sampler` for `model`.
+
+When sampling using [`sample`](@ref), this takes the place of [`AbstractMCMC.step`](@ref) in the first
+`num_warmup` number of iterations, as specified by the `num_warmup` keyword to [`sample`](@ref).
+This is useful if the sampler has an initial "warmup"-stage that is different from the
+standard iteration.
+
+By default, this simply calls [`AbstractMCMC.step`](@ref).
+"""
+step_warmup(rng, model, sampler; kwargs...) = step(rng, model, sampler; kwargs...)
+function step_warmup(rng, model, sampler, state; kwargs...)
+    return step(rng, model, sampler, state; kwargs...)
+end
+
 """
     samples(sample, model, sampler[, N; kwargs...])
 
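
The fallback pattern in this hunk means a sampler only needs to overload the warm-up step when its warm-up phase differs from regular sampling. A minimal sketch follows, using stand-in functions `demo_step` and `demo_step_warmup` instead of the real `AbstractMCMC.step` and `AbstractMCMC.step_warmup` so that it runs without the package; the sampler types and return values are made up for illustration.

```julia
# Stand-ins for the interface functions (assumption: same fallback shape as the diff).
function demo_step end
demo_step_warmup(rng, model, sampler; kwargs...) = demo_step(rng, model, sampler; kwargs...)
function demo_step_warmup(rng, model, sampler, state; kwargs...)
    return demo_step(rng, model, sampler, state; kwargs...)
end

# Hypothetical samplers; the state is just an iteration counter.
struct PlainSampler end
struct AdaptiveSampler end

demo_step(rng, model, ::PlainSampler, state=0; kwargs...) = (:regular, state + 1)
demo_step(rng, model, ::AdaptiveSampler, state=0; kwargs...) = (:regular, state + 1)

# Only `AdaptiveSampler` opts into special warm-up behaviour.
function demo_step_warmup(rng, model, ::AdaptiveSampler, state; kwargs...)
    return (:adapting, state + 1)
end

demo_step_warmup(nothing, nothing, PlainSampler(), 0)     # falls back to `demo_step`
demo_step_warmup(nothing, nothing, AdaptiveSampler(), 0)  # uses the overload
```

Because the fallback simply forwards to the regular step, samplers that never define a warm-up method behave the same whether or not `num_warmup` is set.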

src/sample.jl

+102 -30
@@ -43,6 +43,11 @@ isdone(rng, model, sampler, samples, state, iteration; kwargs...)
 ```
 where `state` and `iteration` are the current state and iteration of the sampler, respectively.
 It should return `true` when sampling should end, and `false` otherwise.
+
+# Keyword arguments
+
+See https://turinglang.org/AbstractMCMC.jl/dev/api/#Common-keyword-arguments for common keyword
+arguments.
 """
 function StatsBase.sample(
     rng::Random.AbstractRNG,
@@ -80,6 +85,11 @@ end
 
 Sample `nchains` Monte Carlo Markov chains from the `model` with the `sampler` in parallel
 using the `parallel` algorithm, and combine them into a single chain.
+
+# Keyword arguments
+
+See https://turinglang.org/AbstractMCMC.jl/dev/api/#Common-keyword-arguments for common keyword
+arguments.
 """
 function StatsBase.sample(
     rng::Random.AbstractRNG,
@@ -94,7 +104,6 @@ function StatsBase.sample(
 end
 
 # Default implementations of regular and parallel sampling.
-
 function mcmcsample(
     rng::Random.AbstractRNG,
     model::AbstractModel,
@@ -103,15 +112,28 @@ function mcmcsample(
     progress=PROGRESS[],
     progressname="Sampling",
     callback=nothing,
-    discard_initial=0,
+    num_warmup::Int=0,
+    discard_initial::Int=num_warmup,
     thinning=1,
     chain_type::Type=Any,
     initial_state=nothing,
     kwargs...,
 )
     # Check the number of requested samples.
     N > 0 || error("the number of samples must be ≥ 1")
+    discard_initial >= 0 ||
+        throw(ArgumentError("number of discarded samples must be non-negative"))
+    num_warmup >= 0 ||
+        throw(ArgumentError("number of warm-up samples must be non-negative"))
     Ntotal = thinning * (N - 1) + discard_initial + 1
+    Ntotal >= num_warmup || throw(
+        ArgumentError("number of warm-up samples exceeds the total number of samples")
+    )
+
+    # Determine how many samples to drop from `num_warmup` and the
+    # main sampling process before we start saving samples.
+    discard_from_warmup = min(num_warmup, discard_initial)
+    keep_from_warmup = num_warmup - discard_from_warmup
 
     # Start the timer
     start = time()
@@ -126,22 +148,41 @@ function mcmcsample(
         end
 
         # Obtain the initial sample and state.
-        sample, state = if initial_state === nothing
-            step(rng, model, sampler; kwargs...)
+        sample, state = if num_warmup > 0
+            if initial_state === nothing
+                step_warmup(rng, model, sampler; kwargs...)
+            else
+                step_warmup(rng, model, sampler, initial_state; kwargs...)
+            end
         else
-            step(rng, model, sampler, initial_state; kwargs...)
+            if initial_state === nothing
+                step(rng, model, sampler; kwargs...)
+            else
+                step(rng, model, sampler, initial_state; kwargs...)
+            end
+        end
+
+        # Update the progress bar.
+        itotal = 1
+        if progress && itotal >= next_update
+            ProgressLogging.@logprogress itotal / Ntotal
+            next_update = itotal + threshold
         end
 
         # Discard initial samples.
-        for i in 1:discard_initial
-            # Update the progress bar.
-            if progress && i >= next_update
-                ProgressLogging.@logprogress i / Ntotal
-                next_update = i + threshold
+        for j in 1:discard_initial
+            # Obtain the next sample and state.
+            sample, state = if j <= num_warmup
+                step_warmup(rng, model, sampler, state; kwargs...)
+            else
+                step(rng, model, sampler, state; kwargs...)
             end
 
-            # Obtain the next sample and state.
-            sample, state = step(rng, model, sampler, state; kwargs...)
+            # Update the progress bar.
+            if progress && (itotal += 1) >= next_update
+                ProgressLogging.@logprogress itotal / Ntotal
+                next_update = itotal + threshold
+            end
         end
 
         # Run callback.
@@ -151,19 +192,16 @@ function mcmcsample(
         samples = AbstractMCMC.samples(sample, model, sampler, N; kwargs...)
         samples = save!!(samples, sample, 1, model, sampler, N; kwargs...)
 
-        # Update the progress bar.
-        itotal = 1 + discard_initial
-        if progress && itotal >= next_update
-            ProgressLogging.@logprogress itotal / Ntotal
-            next_update = itotal + threshold
-        end
-
         # Step through the sampler.
         for i in 2:N
             # Discard thinned samples.
             for _ in 1:(thinning - 1)
                 # Obtain the next sample and state.
-                sample, state = step(rng, model, sampler, state; kwargs...)
+                sample, state = if i <= keep_from_warmup
+                    step_warmup(rng, model, sampler, state; kwargs...)
+                else
+                    step(rng, model, sampler, state; kwargs...)
+                end
 
                 # Update progress bar.
                 if progress && (itotal += 1) >= next_update
@@ -173,7 +211,11 @@ function mcmcsample(
             end
 
             # Obtain the next sample and state.
-            sample, state = step(rng, model, sampler, state; kwargs...)
+            sample, state = if i <= keep_from_warmup
+                step_warmup(rng, model, sampler, state; kwargs...)
+            else
+                step(rng, model, sampler, state; kwargs...)
+            end
 
             # Run callback.
             callback === nothing ||
@@ -217,28 +259,51 @@ function mcmcsample(
     progress=PROGRESS[],
     progressname="Convergence sampling",
     callback=nothing,
-    discard_initial=0,
+    num_warmup=0,
+    discard_initial=num_warmup,
     thinning=1,
     initial_state=nothing,
     kwargs...,
 )
+    # Check the number of requested samples.
+    discard_initial >= 0 ||
+        throw(ArgumentError("number of discarded samples must be non-negative"))
+    num_warmup >= 0 ||
+        throw(ArgumentError("number of warm-up samples must be non-negative"))
+
+    # Determine how many samples to drop from `num_warmup` and the
+    # main sampling process before we start saving samples.
+    discard_from_warmup = min(num_warmup, discard_initial)
+    keep_from_warmup = num_warmup - discard_from_warmup
 
     # Start the timer
     start = time()
     local state
 
     @ifwithprogresslogger progress name = progressname begin
         # Obtain the initial sample and state.
-        sample, state = if initial_state === nothing
-            step(rng, model, sampler; kwargs...)
+        sample, state = if num_warmup > 0
+            if initial_state === nothing
+                step_warmup(rng, model, sampler; kwargs...)
+            else
+                step_warmup(rng, model, sampler, initial_state; kwargs...)
+            end
         else
-            step(rng, model, sampler, state; kwargs...)
+            if initial_state === nothing
+                step(rng, model, sampler; kwargs...)
+            else
+                step(rng, model, sampler, initial_state; kwargs...)
+            end
         end
 
         # Discard initial samples.
-        for _ in 1:discard_initial
+        for j in 1:discard_initial
             # Obtain the next sample and state.
-            sample, state = step(rng, model, sampler, state; kwargs...)
+            sample, state = if j <= num_warmup
+                step_warmup(rng, model, sampler, state; kwargs...)
+            else
+                step(rng, model, sampler, state; kwargs...)
+            end
         end
 
         # Run callback.
@@ -250,16 +315,23 @@ function mcmcsample(
 
         # Step through the sampler until stopping.
         i = 2
-
         while !isdone(rng, model, sampler, samples, state, i; progress=progress, kwargs...)
             # Discard thinned samples.
             for _ in 1:(thinning - 1)
                 # Obtain the next sample and state.
-                sample, state = step(rng, model, sampler, state; kwargs...)
+                sample, state = if i <= keep_from_warmup
+                    step_warmup(rng, model, sampler, state; kwargs...)
+                else
+                    step(rng, model, sampler, state; kwargs...)
+                end
             end
 
             # Obtain the next sample and state.
-            sample, state = step(rng, model, sampler, state; kwargs...)
+            sample, state = if i <= keep_from_warmup
+                step_warmup(rng, model, sampler, state; kwargs...)
+            else
+                step(rng, model, sampler, state; kwargs...)
+            end
 
             # Run callback.
             callback === nothing ||
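
To see which saved samples come from warm-up steps, the control flow of the fixed-`N` method can be replayed without a real sampler. The function `saved_sample_kinds` below is a hypothetical sketch, not part of AbstractMCMC. It ignores thinning and assumes the comparisons are `j <= num_warmup` and `i <= keep_from_warmup`, which agrees with the `is_warmup` assertion in `test/sample.jl`.

```julia
# Hypothetical replay of the loop structure: record which step function
# produces each sample that ends up in the chain (thinning = 1).
function saved_sample_kinds(N; num_warmup=0, discard_initial=num_warmup)
    keep_from_warmup = num_warmup - min(num_warmup, discard_initial)

    # Initial step.
    kind = num_warmup > 0 ? :step_warmup : :step

    # Discarded steps: only the most recent sample survives to be saved first.
    for j in 1:discard_initial
        kind = j <= num_warmup ? :step_warmup : :step
    end
    kinds = [kind]

    # Remaining saved samples.
    for i in 2:N
        push!(kinds, i <= keep_from_warmup ? :step_warmup : :step)
    end
    return kinds
end

saved_sample_kinds(5; num_warmup=3, discard_initial=1)
```

Under these assumptions the call returns two `:step_warmup` entries followed by three `:step` entries, i.e. `num_warmup - discard_initial` warm-up samples at the start of the chain.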

test/sample.jl

+39
@@ -575,6 +575,45 @@
         @test all(chain[i].b == ref_chain[i + discard_initial].b for i in 1:N)
     end
 
+    @testset "Warm-up steps" begin
+        # Create a chain and discard initial samples.
+        Random.seed!(1234)
+        N = 100
+        num_warmup = 50
+
+        # Everything should be discarded here.
+        chain = sample(MyModel(), MySampler(), N; num_warmup=num_warmup)
+        @test length(chain) == N
+        @test !ismissing(chain[1].a)
+
+        # Repeat sampling without discarding initial samples.
+        # On Julia < 1.6 progress logging changes the global RNG and hence is enabled here.
+        # https://github.com/TuringLang/AbstractMCMC.jl/pull/102#issuecomment-1142253258
+        Random.seed!(1234)
+        ref_chain = sample(
+            MyModel(), MySampler(), N + num_warmup; progress=VERSION < v"1.6"
+        )
+        @test all(chain[i].a == ref_chain[i + num_warmup].a for i in 1:N)
+        @test all(chain[i].b == ref_chain[i + num_warmup].b for i in 1:N)
+
+        # Some other stuff.
+        Random.seed!(1234)
+        discard_initial = 10
+        chain_warmup = sample(
+            MyModel(),
+            MySampler(),
+            N;
+            num_warmup=num_warmup,
+            discard_initial=discard_initial,
+        )
+        @test length(chain_warmup) == N
+        @test all(chain_warmup[i].a == ref_chain[i + discard_initial].a for i in 1:N)
+        # Check that the first `num_warmup - discard_initial` samples are warmup samples.
+        @test all(
+            chain_warmup[i].is_warmup == (i <= num_warmup - discard_initial) for i in 1:N
+        )
+    end
+
     @testset "Thin chain by a factor of `thinning`" begin
         # Run a thinned chain with `N` samples thinned by factor of `thinning`.
         Random.seed!(100)

test/utils.jl

+18
@@ -3,8 +3,11 @@ struct MyModel <: AbstractMCMC.AbstractModel end
 struct MySample{A,B}
     a::A
     b::B
+    is_warmup::Bool
 end
 
+MySample(a, b) = MySample(a, b, false)
+
 struct MySampler <: AbstractMCMC.AbstractSampler end
 struct AnotherSampler <: AbstractMCMC.AbstractSampler end
 
@@ -16,6 +19,21 @@ end
 
 MyChain(a, b) = MyChain(a, b, NamedTuple())
 
+function AbstractMCMC.step_warmup(
+    rng::AbstractRNG,
+    model::MyModel,
+    sampler::MySampler,
+    state::Union{Nothing,Integer}=nothing;
+    loggers=false,
+    initial_params=nothing,
+    kwargs...,
+)
+    transition, state = AbstractMCMC.step(
+        rng, model, sampler, state; loggers, initial_params, kwargs...
+    )
+    return MySample(transition.a, transition.b, true), state
+end
+
 function AbstractMCMC.step(
     rng::AbstractRNG,
     model::MyModel,
