Skip to content

Commit 8d45ff4

Browse files
authored
Merge pull request #119 from TuringLang/torfjelde/initial-state
Allow specification of initial state for `sample`
2 parents d521815 + 3ed5314 commit 8d45ff4

File tree

6 files changed

+195
-68
lines changed

6 files changed

+195
-68
lines changed

Diff for: Project.toml

+2
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ version = "4.5.0"
99
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
1010
ConsoleProgressMonitor = "88cd18e8-d9cc-4ea6-8889-5259c0d15c8b"
1111
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
12+
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
1213
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
1314
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
1415
LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36"
@@ -21,6 +22,7 @@ Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999"
2122
[compat]
2223
BangBang = "0.3.19"
2324
ConsoleProgressMonitor = "0.1"
25+
FillArrays = "1"
2426
LogDensityProblems = "2"
2527
LoggingExtras = "0.4, 0.5, 1"
2628
ProgressLogging = "0.1"

Diff for: docs/src/api.md

+5-3
Original file line numberDiff line numberDiff line change
@@ -75,14 +75,16 @@ Common keyword arguments for regular and parallel sampling are:
7575
where `sample` is the most recent sample of the Markov chain and `state` and `iteration` are the current state and iteration of the sampler
7676
- `discard_initial` (default: `0`): number of initial samples that are discarded
7777
- `thinning` (default: `1`): factor by which to thin samples.
78+
- `initial_state` (default: `nothing`): if `initial_state !== nothing`, the first call to [`AbstractMCMC.step`](@ref)
79+
is passed `initial_state` as the `state` argument.
7880

7981
!!! info
8082
The common keyword arguments `progress`, `chain_type`, and `callback` are not supported by the iterator [`AbstractMCMC.steps`](@ref) and the transducer [`AbstractMCMC.Sample`](@ref).
8183

8284
There is no "official" way for providing initial parameter values yet.
83-
However, multiple packages such as [EllipticalSliceSampling.jl](https://github.com/TuringLang/EllipticalSliceSampling.jl) and [AdvancedMH.jl](https://github.com/TuringLang/AdvancedMH.jl) support an `init_params` keyword argument for setting the initial values when sampling a single chain.
84-
To ensure that sampling multiple chains "just works" when sampling of a single chain is implemented, [we decided to support `init_params` in the default implementations of the ensemble methods](https://github.com/TuringLang/AbstractMCMC.jl/pull/94):
85-
- `init_params` (default: `nothing`): if `init_params isa AbstractArray`, then the `i`th element of `init_params` is used as initial parameters of the `i`th chain. If one wants to use the same initial parameters `x` for every chain, one can specify e.g. `init_params = FillArrays.Fill(x, N)`.
85+
However, multiple packages such as [EllipticalSliceSampling.jl](https://github.com/TuringLang/EllipticalSliceSampling.jl) and [AdvancedMH.jl](https://github.com/TuringLang/AdvancedMH.jl) support an `initial_params` keyword argument for setting the initial values when sampling a single chain.
86+
To ensure that sampling multiple chains "just works" when sampling of a single chain is implemented, [we decided to support `initial_params` in the default implementations of the ensemble methods](https://github.com/TuringLang/AbstractMCMC.jl/pull/94):
87+
- `initial_params` (default: `nothing`): if `initial_params isa AbstractArray`, then the `i`th element of `initial_params` is used as initial parameters of the `i`th chain. If one wants to use the same initial parameters `x` for every chain, one can specify e.g. `initial_params = FillArrays.Fill(x, N)`.
8688

8789
Progress logging can be enabled and disabled globally with `AbstractMCMC.setprogress!(progress)`.
8890

Diff for: src/AbstractMCMC.jl

+1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ using ProgressLogging: ProgressLogging
88
using StatsBase: StatsBase
99
using TerminalLoggers: TerminalLoggers
1010
using Transducers: Transducers
11+
using FillArrays: FillArrays
1112

1213
using Distributed: Distributed
1314
using Logging: Logging

Diff for: src/sample.jl

+72-28
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ function mcmcsample(
103103
discard_initial=0,
104104
thinning=1,
105105
chain_type::Type=Any,
106+
initial_state=nothing,
106107
kwargs...,
107108
)
108109
# Check the number of requested samples.
@@ -122,7 +123,11 @@ function mcmcsample(
122123
end
123124

124125
# Obtain the initial sample and state.
125-
sample, state = step(rng, model, sampler; kwargs...)
126+
sample, state = if initial_state === nothing
127+
step(rng, model, sampler; kwargs...)
128+
else
129+
step(rng, model, sampler, initial_state; kwargs...)
130+
end
126131

127132
# Discard initial samples.
128133
for i in 1:discard_initial
@@ -211,6 +216,7 @@ function mcmcsample(
211216
callback=nothing,
212217
discard_initial=0,
213218
thinning=1,
219+
initial_state=nothing,
214220
kwargs...,
215221
)
216222

@@ -220,7 +226,11 @@ function mcmcsample(
220226

221227
@ifwithprogresslogger progress name = progressname begin
222228
# Obtain the initial sample and state.
223-
sample, state = step(rng, model, sampler; kwargs...)
229+
sample, state = if initial_state === nothing
230+
step(rng, model, sampler; kwargs...)
231+
else
232+
step(rng, model, sampler, state; kwargs...)
233+
end
224234

225235
# Discard initial samples.
226236
for _ in 1:discard_initial
@@ -288,7 +298,8 @@ function mcmcsample(
288298
nchains::Integer;
289299
progress=PROGRESS[],
290300
progressname="Sampling ($(min(nchains, Threads.nthreads())) threads)",
291-
init_params=nothing,
301+
initial_params=nothing,
302+
initial_state=nothing,
292303
kwargs...,
293304
)
294305
# Check if actually multiple threads are used.
@@ -312,8 +323,9 @@ function mcmcsample(
312323
# Create a seed for each chain using the provided random number generator.
313324
seeds = rand(rng, UInt, nchains)
314325

315-
# Ensure that initial parameters are `nothing` or of the correct length
316-
check_initial_params(init_params, nchains)
326+
# Ensure that initial parameters and states are `nothing` or of the correct length
327+
check_initial_params(initial_params, nchains)
328+
check_initial_state(initial_state, nchains)
317329

318330
# Set up a chains vector.
319331
chains = Vector{Any}(undef, nchains)
@@ -364,10 +376,15 @@ function mcmcsample(
364376
_sampler,
365377
N;
366378
progress=false,
367-
init_params=if init_params === nothing
379+
initial_params=if initial_params === nothing
380+
nothing
381+
else
382+
initial_params[chainidx]
383+
end,
384+
initial_state=if initial_state === nothing
368385
nothing
369386
else
370-
init_params[chainidx]
387+
initial_state[chainidx]
371388
end,
372389
kwargs...,
373390
)
@@ -397,7 +414,8 @@ function mcmcsample(
397414
nchains::Integer;
398415
progress=PROGRESS[],
399416
progressname="Sampling ($(Distributed.nworkers()) processes)",
400-
init_params=nothing,
417+
initial_params=nothing,
418+
initial_state=nothing,
401419
kwargs...,
402420
)
403421
# Check if actually multiple processes are used.
@@ -410,8 +428,14 @@ function mcmcsample(
410428
@warn "Number of chains ($nchains) is greater than number of samples per chain ($N)"
411429
end
412430

413-
# Ensure that initial parameters are `nothing` or of the correct length
414-
check_initial_params(init_params, nchains)
431+
# Ensure that initial parameters and states are `nothing` or of the correct length
432+
check_initial_params(initial_params, nchains)
433+
check_initial_state(initial_state, nchains)
434+
435+
_initial_params =
436+
initial_params === nothing ? FillArrays.Fill(nothing, nchains) : initial_params
437+
_initial_state =
438+
initial_state === nothing ? FillArrays.Fill(nothing, nchains) : initial_state
415439

416440
# Create a seed for each chain using the provided random number generator.
417441
seeds = rand(rng, UInt, nchains)
@@ -448,7 +472,7 @@ function mcmcsample(
448472

449473
Distributed.@async begin
450474
try
451-
function sample_chain(seed, init_params=nothing)
475+
function sample_chain(seed, initial_params, initial_state)
452476
# Seed a new random number generator with the pre-made seed.
453477
Random.seed!(rng, seed)
454478

@@ -459,7 +483,8 @@ function mcmcsample(
459483
sampler,
460484
N;
461485
progress=false,
462-
init_params=init_params,
486+
initial_params=initial_params,
487+
initial_state=initial_state,
463488
kwargs...,
464489
)
465490

@@ -469,11 +494,9 @@ function mcmcsample(
469494
# Return the new chain.
470495
return chain
471496
end
472-
chains = if init_params === nothing
473-
Distributed.pmap(sample_chain, pool, seeds)
474-
else
475-
Distributed.pmap(sample_chain, pool, seeds, init_params)
476-
end
497+
chains = Distributed.pmap(
498+
sample_chain, pool, seeds, _initial_params, _initial_state
499+
)
477500
finally
478501
# Stop updating the progress bar.
479502
progress && put!(channel, false)
@@ -494,22 +517,29 @@ function mcmcsample(
494517
N::Integer,
495518
nchains::Integer;
496519
progressname="Sampling",
497-
init_params=nothing,
520+
initial_params=nothing,
521+
initial_state=nothing,
498522
kwargs...,
499523
)
500524
# Check if the number of chains is larger than the number of samples
501525
if nchains > N
502526
@warn "Number of chains ($nchains) is greater than number of samples per chain ($N)"
503527
end
504528

505-
# Ensure that initial parameters are `nothing` or of the correct length
506-
check_initial_params(init_params, nchains)
529+
# Ensure that initial parameters and states are `nothing` or of the correct length
530+
check_initial_params(initial_params, nchains)
531+
check_initial_state(initial_state, nchains)
532+
533+
_initial_params =
534+
initial_params === nothing ? FillArrays.Fill(nothing, nchains) : initial_params
535+
_initial_state =
536+
initial_state === nothing ? FillArrays.Fill(nothing, nchains) : initial_state
507537

508538
# Create a seed for each chain using the provided random number generator.
509539
seeds = rand(rng, UInt, nchains)
510540

511541
# Sample the chains.
512-
function sample_chain(i, seed, init_params=nothing)
542+
function sample_chain(i, seed, initial_params, initial_state)
513543
# Seed a new random number generator with the pre-made seed.
514544
Random.seed!(rng, seed)
515545

@@ -520,16 +550,13 @@ function mcmcsample(
520550
sampler,
521551
N;
522552
progressname=string(progressname, " (Chain ", i, " of ", nchains, ")"),
523-
init_params=init_params,
553+
initial_params=initial_params,
554+
initial_state=initial_state,
524555
kwargs...,
525556
)
526557
end
527558

528-
chains = if init_params === nothing
529-
map(sample_chain, 1:nchains, seeds)
530-
else
531-
map(sample_chain, 1:nchains, seeds, init_params)
532-
end
559+
chains = map(sample_chain, 1:nchains, seeds, _initial_params, _initial_state)
533560

534561
# Concatenate the chains together.
535562
return chainsstack(tighten_eltype(chains))
@@ -543,7 +570,6 @@ tighten_eltype(x::Vector{Any}) = map(identity, x)
543570
"initial parameters must be specified as a vector of length equal to the number of chains or `nothing`",
544571
),
545572
)
546-
547573
check_initial_params(::Nothing, n) = nothing
548574
function check_initial_params(x::AbstractArray, n)
549575
if length(x) != n
@@ -556,3 +582,21 @@ function check_initial_params(x::AbstractArray, n)
556582

557583
return nothing
558584
end
585+
586+
@nospecialize check_initial_state(x, n) = throw(
587+
ArgumentError(
588+
"initial states must be specified as a vector of length equal to the number of chains or `nothing`",
589+
),
590+
)
591+
check_initial_state(::Nothing, n) = nothing
592+
function check_initial_state(x::AbstractArray, n)
593+
if length(x) != n
594+
throw(
595+
ArgumentError(
596+
"incorrect number of initial states (expected $n, received $(length(x))"
597+
),
598+
)
599+
end
600+
601+
return nothing
602+
end

0 commit comments

Comments
 (0)