Skip to content

Commit 084f809

Browse files
authored
Merge pull request #126 from TuringLang/torfjelde/init-params-fix
Use _init_params for MCMCThreads and MCMCDistributed too
2 parents caeade2 + 4dbcb3f commit 084f809

File tree

5 files changed

+117
-50
lines changed

5 files changed

+117
-50
lines changed

Project.toml

+3-2
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
33
keywords = ["markov chain monte carlo", "probabilistic programming"]
44
license = "MIT"
55
desc = "A lightweight interface for common MCMC methods."
6-
version = "4.4.2"
6+
version = "4.6.0"
77

88
[deps]
99
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
@@ -30,9 +30,10 @@ Transducers = "0.4.30"
3030
julia = "1.6"
3131

3232
[extras]
33+
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
3334
IJulia = "7073ff75-c697-5162-941a-fcdaad2a7d2a"
3435
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
3536
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3637

3738
[targets]
38-
test = ["IJulia", "Statistics", "Test"]
39+
test = ["FillArrays", "IJulia", "Statistics", "Test"]

docs/src/api.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ Common keyword arguments for regular and parallel sampling are:
8282
There is no "official" way for providing initial parameter values yet.
8383
However, multiple packages such as [EllipticalSliceSampling.jl](https://github.com/TuringLang/EllipticalSliceSampling.jl) and [AdvancedMH.jl](https://github.com/TuringLang/AdvancedMH.jl) support an `init_params` keyword argument for setting the initial values when sampling a single chain.
8484
To ensure that sampling multiple chains "just works" when sampling of a single chain is implemented, [we decided to support `init_params` in the default implementations of the ensemble methods](https://github.com/TuringLang/AbstractMCMC.jl/pull/94):
85-
- `init_params` (default: `nothing`): if set to `init_params !== nothing`, then the `i`th element of `init_params` is used as initial parameters of the `i`th chain. If one wants to use the same initial parameters `x` for every chain, one can specify e.g. `init_params = Iterators.repeated(x)` or `init_params = FillArrays.Fill(x, N)`.
85+
- `init_params` (default: `nothing`): if `init_params isa AbstractArray`, then the `i`th element of `init_params` is used as initial parameters of the `i`th chain. If one wants to use the same initial parameters `x` for every chain, one can specify e.g. `init_params = FillArrays.Fill(x, N)`.
8686

8787
Progress logging can be enabled and disabled globally with `AbstractMCMC.setprogress!(progress)`.
8888

src/sample.jl

+25-29
Original file line numberDiff line numberDiff line change
@@ -312,8 +312,8 @@ function mcmcsample(
312312
# Create a seed for each chain using the provided random number generator.
313313
seeds = rand(rng, UInt, nchains)
314314

315-
# Ensure that initial parameters are `nothing` or indexable
316-
_init_params = _first_or_nothing(init_params, nchains)
315+
# Ensure that initial parameters are `nothing` or of the correct length
316+
check_initial_params(init_params, nchains)
317317

318318
# Set up a chains vector.
319319
chains = Vector{Any}(undef, nchains)
@@ -364,10 +364,10 @@ function mcmcsample(
364364
_sampler,
365365
N;
366366
progress=false,
367-
init_params=if _init_params === nothing
367+
init_params=if init_params === nothing
368368
nothing
369369
else
370-
_init_params[chainidx]
370+
init_params[chainidx]
371371
end,
372372
kwargs...,
373373
)
@@ -410,6 +410,9 @@ function mcmcsample(
410410
@warn "Number of chains ($nchains) is greater than number of samples per chain ($N)"
411411
end
412412

413+
# Ensure that initial parameters are `nothing` or of the correct length
414+
check_initial_params(init_params, nchains)
415+
413416
# Create a seed for each chain using the provided random number generator.
414417
seeds = rand(rng, UInt, nchains)
415418

@@ -499,6 +502,9 @@ function mcmcsample(
499502
@warn "Number of chains ($nchains) is greater than number of samples per chain ($N)"
500503
end
501504

505+
# Ensure that initial parameters are `nothing` or of the correct length
506+
check_initial_params(init_params, nchains)
507+
502508
# Create a seed for each chain using the provided random number generator.
503509
seeds = rand(rng, UInt, nchains)
504510

@@ -532,31 +538,21 @@ end
532538
tighten_eltype(x) = x
533539
tighten_eltype(x::Vector{Any}) = map(identity, x)
534540

535-
"""
536-
_first_or_nothing(x, n::Int)
537-
538-
Return the first `n` elements of collection `x`, or `nothing` if `x === nothing`.
539-
540-
If `x !== nothing`, then `x` has to contain at least `n` elements.
541-
"""
542-
function _first_or_nothing(x, n::Int)
543-
y = _first(x, n)
544-
length(y) == n || throw(
545-
ArgumentError("not enough initial parameters (expected $n, received $(length(y))"),
546-
)
547-
return y
548-
end
549-
_first_or_nothing(::Nothing, ::Int) = nothing
541+
@nospecialize check_initial_params(x, n) = throw(
542+
ArgumentError(
543+
"initial parameters must be specified as a vector of length equal to the number of chains or `nothing`",
544+
),
545+
)
550546

551-
# `first(x, n::Int)` requires Julia 1.6
552-
function _first(x, n::Int)
553-
@static if VERSION >= v"1.6.0-DEV.431"
554-
first(x, n)
555-
else
556-
if x isa AbstractVector
557-
@inbounds x[firstindex(x):min(firstindex(x) + n - 1, lastindex(x))]
558-
else
559-
collect(Iterators.take(x, n))
560-
end
547+
check_initial_params(::Nothing, n) = nothing
548+
function check_initial_params(x::AbstractArray, n)
549+
if length(x) != n
550+
throw(
551+
ArgumentError(
552+
"incorrect number of initial parameters (expected $n, received $(length(x)))"
553+
),
554+
)
561555
end
556+
557+
return nothing
562558
end

test/runtests.jl

+1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ using IJulia
44
using LogDensityProblems
55
using LoggingExtras: TeeLogger, EarlyFilteredLogger
66
using TerminalLoggers: TerminalLogger
7+
using FillArrays: FillArrays
78
using Transducers
89

910
using Distributed

test/sample.jl

+87-18
Original file line numberDiff line numberDiff line change
@@ -162,17 +162,18 @@
162162
end
163163

164164
# initial parameters
165-
init_params = [(b=randn(), a=rand()) for _ in 1:100]
165+
nchains = 100
166+
init_params = [(b=randn(), a=rand()) for _ in 1:nchains]
166167
chains = sample(
167168
MyModel(),
168169
MySampler(),
169170
MCMCThreads(),
170171
3,
171-
100;
172+
nchains;
172173
progress=false,
173174
init_params=init_params,
174175
)
175-
@test length(chains) == 100
176+
@test length(chains) == nchains
176177
@test all(
177178
chain[1].a == params.a && chain[1].b == params.b for
178179
(chain, params) in zip(chains, init_params)
@@ -184,14 +185,36 @@
184185
MySampler(),
185186
MCMCThreads(),
186187
3,
187-
100;
188+
nchains;
188189
progress=false,
189-
init_params=Iterators.repeated(init_params),
190+
init_params=FillArrays.Fill(init_params, nchains),
190191
)
191-
@test length(chains) == 100
192+
@test length(chains) == nchains
192193
@test all(
193194
chain[1].a == init_params.a && chain[1].b == init_params.b for chain in chains
194195
)
196+
197+
# Too many `init_params`
198+
@test_throws ArgumentError sample(
199+
MyModel(),
200+
MySampler(),
201+
MCMCThreads(),
202+
3,
203+
nchains;
204+
progress=false,
205+
init_params=FillArrays.Fill(init_params, nchains + 1),
206+
)
207+
208+
# Too few `init_params`
209+
@test_throws ArgumentError sample(
210+
MyModel(),
211+
MySampler(),
212+
MCMCThreads(),
213+
3,
214+
nchains;
215+
progress=false,
216+
init_params=FillArrays.Fill(init_params, nchains - 1),
217+
)
195218
end
196219

197220
@testset "Multicore sampling" begin
@@ -274,17 +297,18 @@
274297
@test all(l.level > Logging.LogLevel(-1) for l in logs)
275298

276299
# initial parameters
277-
init_params = [(a=randn(), b=rand()) for _ in 1:100]
300+
nchains = 100
301+
init_params = [(a=randn(), b=rand()) for _ in 1:nchains]
278302
chains = sample(
279303
MyModel(),
280304
MySampler(),
281305
MCMCDistributed(),
282306
3,
283-
100;
307+
nchains;
284308
progress=false,
285309
init_params=init_params,
286310
)
287-
@test length(chains) == 100
311+
@test length(chains) == nchains
288312
@test all(
289313
chain[1].a == params.a && chain[1].b == params.b for
290314
(chain, params) in zip(chains, init_params)
@@ -296,15 +320,37 @@
296320
MySampler(),
297321
MCMCDistributed(),
298322
3,
299-
100;
323+
nchains;
300324
progress=false,
301-
init_params=Iterators.repeated(init_params),
325+
init_params=FillArrays.Fill(init_params, nchains),
302326
)
303-
@test length(chains) == 100
327+
@test length(chains) == nchains
304328
@test all(
305329
chain[1].a == init_params.a && chain[1].b == init_params.b for chain in chains
306330
)
307331

332+
# Too many `init_params`
333+
@test_throws ArgumentError sample(
334+
MyModel(),
335+
MySampler(),
336+
MCMCDistributed(),
337+
3,
338+
nchains;
339+
progress=false,
340+
init_params=FillArrays.Fill(init_params, nchains + 1),
341+
)
342+
343+
# Too few `init_params`
344+
@test_throws ArgumentError sample(
345+
MyModel(),
346+
MySampler(),
347+
MCMCDistributed(),
348+
3,
349+
nchains;
350+
progress=false,
351+
init_params=FillArrays.Fill(init_params, nchains - 1),
352+
)
353+
308354
# Remove workers
309355
rmprocs(pids...)
310356
end
@@ -360,17 +406,18 @@
360406
@test all(l.level > Logging.LogLevel(-1) for l in logs)
361407

362408
# initial parameters
363-
init_params = [(a=rand(), b=randn()) for _ in 1:100]
409+
nchains = 100
410+
init_params = [(a=rand(), b=randn()) for _ in 1:nchains]
364411
chains = sample(
365412
MyModel(),
366413
MySampler(),
367414
MCMCSerial(),
368415
3,
369-
100;
416+
nchains;
370417
progress=false,
371418
init_params=init_params,
372419
)
373-
@test length(chains) == 100
420+
@test length(chains) == nchains
374421
@test all(
375422
chain[1].a == params.a && chain[1].b == params.b for
376423
(chain, params) in zip(chains, init_params)
@@ -382,14 +429,36 @@
382429
MySampler(),
383430
MCMCSerial(),
384431
3,
385-
100;
432+
nchains;
386433
progress=false,
387-
init_params=Iterators.repeated(init_params),
434+
init_params=FillArrays.Fill(init_params, nchains),
388435
)
389-
@test length(chains) == 100
436+
@test length(chains) == nchains
390437
@test all(
391438
chain[1].a == init_params.a && chain[1].b == init_params.b for chain in chains
392439
)
440+
441+
# Too many `init_params`
442+
@test_throws ArgumentError sample(
443+
MyModel(),
444+
MySampler(),
445+
MCMCSerial(),
446+
3,
447+
nchains;
448+
progress=false,
449+
init_params=FillArrays.Fill(init_params, nchains + 1),
450+
)
451+
452+
# Too few `init_params`
453+
@test_throws ArgumentError sample(
454+
MyModel(),
455+
MySampler(),
456+
MCMCSerial(),
457+
3,
458+
nchains;
459+
progress=false,
460+
init_params=FillArrays.Fill(init_params, nchains - 1),
461+
)
393462
end
394463

395464
@testset "Ensemble sampling: Reproducibility" begin

0 commit comments

Comments
 (0)