Skip to content

Commit 3de7393

Browse files
authored
Support init_params in ensemble methods (#94)
* Support `init_params` in ensemble methods * Fix typo * Fix typo * Add documentation * Support `Iterators.Repeated` * Breaking release * Fix and simplify docs setup * Remove deprecations * Reduce tasks on Windows * Generalize to arbitrary collections * Use Blue style
1 parent 4994a79 commit 3de7393

File tree

9 files changed

+187
-17
lines changed

9 files changed

+187
-17
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
33
keywords = ["markov chain monte carlo", "probablistic programming"]
44
license = "MIT"
55
desc = "A lightweight interface for common MCMC methods."
6-
version = "3.3.1"
6+
version = "4.0.0"
77

88
[deps]
99
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"

docs/src/api.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,11 @@ are:
5353
- `discard_initial` (default: `0`): number of initial samples that are discarded
5454
- `thinning` (default: `1`): factor by which to thin samples.
5555

56+
There is no "official" way for providing initial parameter values yet.
57+
However, multiple packages such as [EllipticalSliceSampling.jl](https://github.com/TuringLang/EllipticalSliceSampling.jl) and [AdvancedMH.jl](https://github.com/TuringLang/AdvancedMH.jl) support an `init_params` keyword argument for setting the initial values when sampling a single chain.
58+
To ensure that sampling multiple chains "just works" when sampling of a single chain is implemented, [we decided to support `init_params` in the default implementations of the ensemble methods](https://github.com/TuringLang/AbstractMCMC.jl/pull/94):
59+
- `init_params` (default: `nothing`): if set to `init_params !== nothing`, then the `i`th element of `init_params` is used as initial parameters of the `i`th chain. If one wants to use the same initial parameters `x` for every chain, one can specify e.g. `init_params = Iterators.repeated(x)` or `init_params = FillArrays.Fill(x, N)`.
60+
5661
Progress logging can be enabled and disabled globally with `AbstractMCMC.setprogress!(progress)`.
5762

5863
```@docs

src/AbstractMCMC.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,5 @@ include("interface.jl")
8484
include("sample.jl")
8585
include("stepper.jl")
8686
include("transducer.jl")
87-
include("deprecations.jl")
8887

8988
end # module AbstractMCMC

src/deprecations.jl

Lines changed: 0 additions & 2 deletions
This file was deleted.

src/sample.jl

Lines changed: 71 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,7 @@ function mcmcsample(
283283
nchains::Integer;
284284
progress=PROGRESS[],
285285
progressname="Sampling ($(min(nchains, Threads.nthreads())) threads)",
286+
init_params=nothing,
286287
kwargs...,
287288
)
288289
# Check if actually multiple threads are used.
@@ -298,14 +299,17 @@ function mcmcsample(
298299
# Copy the random number generator, model, and sample for each thread
299300
nchunks = min(nchains, Threads.nthreads())
300301
chunksize = cld(nchains, nchunks)
301-
interval = 1:min(nchains, Threads.nthreads())
302+
interval = 1:nchunks
302303
rngs = [deepcopy(rng) for _ in interval]
303304
models = [deepcopy(model) for _ in interval]
304305
samplers = [deepcopy(sampler) for _ in interval]
305306

306307
# Create a seed for each chain using the provided random number generator.
307308
seeds = rand(rng, UInt, nchains)
308309

310+
# Ensure that initial parameters are `nothing` or indexable
311+
_init_params = _first_or_nothing(init_params, nchains)
312+
309313
# Set up a chains vector.
310314
chains = Vector{Any}(undef, nchains)
311315

@@ -350,7 +354,17 @@ function mcmcsample(
350354

351355
# Sample a chain and save it to the vector.
352356
chains[chainidx] = StatsBase.sample(
353-
_rng, _model, _sampler, N; progress=false, kwargs...
357+
_rng,
358+
_model,
359+
_sampler,
360+
N;
361+
progress=false,
362+
init_params=if _init_params === nothing
363+
nothing
364+
else
365+
_init_params[chainidx]
366+
end,
367+
kwargs...,
354368
)
355369

356370
# Update the progress bar.
@@ -378,6 +392,7 @@ function mcmcsample(
378392
nchains::Integer;
379393
progress=PROGRESS[],
380394
progressname="Sampling ($(Distributed.nworkers()) processes)",
395+
init_params=nothing,
381396
kwargs...,
382397
)
383398
# Check if actually multiple processes are used.
@@ -425,13 +440,19 @@ function mcmcsample(
425440

426441
Distributed.@async begin
427442
try
428-
chains = Distributed.pmap(pool, seeds) do seed
443+
function sample_chain(seed, init_params=nothing)
429444
# Seed a new random number generator with the pre-made seed.
430445
Random.seed!(rng, seed)
431446

432447
# Sample a chain.
433448
chain = StatsBase.sample(
434-
rng, model, sampler, N; progress=false, kwargs...
449+
rng,
450+
model,
451+
sampler,
452+
N;
453+
progress=false,
454+
init_params=init_params,
455+
kwargs...,
435456
)
436457

437458
# Update the progress bar.
@@ -440,6 +461,11 @@ function mcmcsample(
440461
# Return the new chain.
441462
return chain
442463
end
464+
chains = if init_params === nothing
465+
Distributed.pmap(sample_chain, pool, seeds)
466+
else
467+
Distributed.pmap(sample_chain, pool, seeds, init_params)
468+
end
443469
finally
444470
# Stop updating the progress bar.
445471
progress && put!(channel, false)
@@ -460,6 +486,7 @@ function mcmcsample(
460486
N::Integer,
461487
nchains::Integer;
462488
progressname="Sampling",
489+
init_params=nothing,
463490
kwargs...,
464491
)
465492
# Check if the number of chains is larger than the number of samples
@@ -471,21 +498,60 @@ function mcmcsample(
471498
seeds = rand(rng, UInt, nchains)
472499

473500
# Sample the chains.
474-
chains = map(enumerate(seeds)) do (i, seed)
501+
function sample_chain(i, seed, init_params=nothing)
502+
# Seed a new random number generator with the pre-made seed.
475503
Random.seed!(rng, seed)
504+
505+
# Sample a chain.
476506
return StatsBase.sample(
477507
rng,
478508
model,
479509
sampler,
480510
N;
481511
progressname=string(progressname, " (Chain ", i, " of ", nchains, ")"),
512+
init_params=init_params,
482513
kwargs...,
483514
)
484515
end
485516

517+
chains = if init_params === nothing
518+
map(sample_chain, 1:nchains, seeds)
519+
else
520+
map(sample_chain, 1:nchains, seeds, init_params)
521+
end
522+
486523
# Concatenate the chains together.
487524
return chainsstack(tighten_eltype(chains))
488525
end
489526

490527
tighten_eltype(x) = x
491528
tighten_eltype(x::Vector{Any}) = map(identity, x)
529+
530+
"""
531+
_first_or_nothing(x, n::Int)
532+
533+
Return the first `n` elements of collection `x`, or `nothing` if `x === nothing`.
534+
535+
If `x !== nothing`, then `x` has to contain at least `n` elements.
536+
"""
537+
function _first_or_nothing(x, n::Int)
538+
y = _first(x, n)
539+
length(y) == n || throw(
540+
ArgumentError("not enough initial parameters (expected $n, received $(length(y))"),
541+
)
542+
return y
543+
end
544+
_first_or_nothing(::Nothing, ::Int) = nothing
545+
546+
# `first(x, n::Int)` requires Julia 1.6
547+
function _first(x, n::Int)
548+
@static if VERSION >= v"1.6.0-DEV.431"
549+
first(x, n)
550+
else
551+
if x isa AbstractVector
552+
@inbounds x[firstindex(x):min(firstindex(x) + n - 1, lastindex(x))]
553+
else
554+
collect(Iterators.take(x, n))
555+
end
556+
end
557+
end

test/deprecations.jl

Lines changed: 0 additions & 4 deletions
This file was deleted.

test/runtests.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,5 +22,4 @@ include("utils.jl")
2222
include("sample.jl")
2323
include("stepper.jl")
2424
include("transducer.jl")
25-
include("deprecations.jl")
2625
end

test/sample.jl

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,13 @@
2525
@test var(x.a for x in tail_chain) 1 / 12 atol = 5e-3
2626
@test mean(x.b for x in tail_chain) 0.0 atol = 5e-2
2727
@test var(x.b for x in tail_chain) 1 atol = 6e-2
28+
29+
# initial parameters
30+
chain = sample(
31+
MyModel(), MySampler(), 3; progress=false, init_params=(b=3.2, a=-1.8)
32+
)
33+
@test chain[1].a == -1.8
34+
@test chain[1].b == 3.2
2835
end
2936

3037
@testset "Juno" begin
@@ -168,6 +175,38 @@
168175
if Threads.nthreads() == 2
169176
sample(MyModel(), MySampler(), MCMCThreads(), N, 1)
170177
end
178+
179+
# initial parameters
180+
init_params = [(b=randn(), a=rand()) for _ in 1:100]
181+
chains = sample(
182+
MyModel(),
183+
MySampler(),
184+
MCMCThreads(),
185+
3,
186+
100;
187+
progress=false,
188+
init_params=init_params,
189+
)
190+
@test length(chains) == 100
191+
@test all(
192+
chain[1].a == params.a && chain[1].b == params.b for
193+
(chain, params) in zip(chains, init_params)
194+
)
195+
196+
init_params = (a=randn(), b=rand())
197+
chains = sample(
198+
MyModel(),
199+
MySampler(),
200+
MCMCThreads(),
201+
3,
202+
100;
203+
progress=false,
204+
init_params=Iterators.repeated(init_params),
205+
)
206+
@test length(chains) == 100
207+
@test all(
208+
chain[1].a == init_params.a && chain[1].b == init_params.b for chain in chains
209+
)
171210
end
172211

173212
@testset "Multicore sampling" begin
@@ -244,6 +283,38 @@
244283
)
245284
end
246285
@test all(l.level > Logging.LogLevel(-1) for l in logs)
286+
287+
# initial parameters
288+
init_params = [(a=randn(), b=rand()) for _ in 1:100]
289+
chains = sample(
290+
MyModel(),
291+
MySampler(),
292+
MCMCDistributed(),
293+
3,
294+
100;
295+
progress=false,
296+
init_params=init_params,
297+
)
298+
@test length(chains) == 100
299+
@test all(
300+
chain[1].a == params.a && chain[1].b == params.b for
301+
(chain, params) in zip(chains, init_params)
302+
)
303+
304+
init_params = (b=randn(), a=rand())
305+
chains = sample(
306+
MyModel(),
307+
MySampler(),
308+
MCMCDistributed(),
309+
3,
310+
100;
311+
progress=false,
312+
init_params=Iterators.repeated(init_params),
313+
)
314+
@test length(chains) == 100
315+
@test all(
316+
chain[1].a == init_params.a && chain[1].b == init_params.b for chain in chains
317+
)
247318
end
248319

249320
@testset "Serial sampling" begin
@@ -295,6 +366,38 @@
295366
)
296367
end
297368
@test all(l.level > Logging.LogLevel(-1) for l in logs)
369+
370+
# initial parameters
371+
init_params = [(a=rand(), b=randn()) for _ in 1:100]
372+
chains = sample(
373+
MyModel(),
374+
MySampler(),
375+
MCMCSerial(),
376+
3,
377+
100;
378+
progress=false,
379+
init_params=init_params,
380+
)
381+
@test length(chains) == 100
382+
@test all(
383+
chain[1].a == params.a && chain[1].b == params.b for
384+
(chain, params) in zip(chains, init_params)
385+
)
386+
387+
init_params = (b=rand(), a=randn())
388+
chains = sample(
389+
MyModel(),
390+
MySampler(),
391+
MCMCSerial(),
392+
3,
393+
100;
394+
progress=false,
395+
init_params=Iterators.repeated(init_params),
396+
)
397+
@test length(chains) == 100
398+
@test all(
399+
chain[1].a == init_params.a && chain[1].b == init_params.b for chain in chains
400+
)
298401
end
299402

300403
@testset "Ensemble sampling: Reproducibility" begin

test/utils.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,15 @@ function AbstractMCMC.step(
2323
state::Union{Nothing,Integer}=nothing;
2424
sleepy=false,
2525
loggers=false,
26+
init_params=nothing,
2627
kwargs...,
2728
)
28-
# sample `a` is missing in the first step
29-
a = state === nothing ? missing : rand(rng)
30-
b = randn(rng)
29+
# sample `a` is missing in the first step if not provided
30+
a, b = if state === nothing && init_params !== nothing
31+
init_params.a, init_params.b
32+
else
33+
(state === nothing ? missing : rand(rng)), randn(rng)
34+
end
3135

3236
loggers && push!(LOGGERS, Logging.current_logger())
3337
sleepy && sleep(0.001)

0 commit comments

Comments
 (0)