Skip to content

Commit 56456c9

Browse files
committed
mroe fixes
1 parent 74465b4 commit 56456c9

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

src/sample.jl

+6-6
Original file line numberDiff line numberDiff line change
@@ -528,11 +528,14 @@ function mcmcsample(
528528
check_initial_params(initial_params, nchains)
529529
check_initial_state(initial_state, nchains)
530530

531+
_initial_params = initial_params === nothing ? FillArrays.Fill(nothing, nchains) : initial_params
532+
_initial_state = initial_state === nothing ? FillArrays.Fill(nothing, nchains) : initial_state
533+
531534
# Create a seed for each chain using the provided random number generator.
532535
seeds = rand(rng, UInt, nchains)
533536

534537
# Sample the chains.
535-
function sample_chain(i, seed, initial_params=nothing)
538+
function sample_chain(i, seed, initial_params, initial_state)
536539
# Seed a new random number generator with the pre-made seed.
537540
Random.seed!(rng, seed)
538541

@@ -544,15 +547,12 @@ function mcmcsample(
544547
N;
545548
progressname=string(progressname, " (Chain ", i, " of ", nchains, ")"),
546549
initial_params=initial_params,
550+
initial_state=initial_state,
547551
kwargs...,
548552
)
549553
end
550554

551-
chains = if initial_params === nothing
552-
map(sample_chain, 1:nchains, seeds)
553-
else
554-
map(sample_chain, 1:nchains, seeds, initial_params)
555-
end
555+
chains = map(sample_chain, 1:nchains, seeds, _initial_params, _initial_state)
556556

557557
# Concatenate the chains together.
558558
return chainsstack(tighten_eltype(chains))

0 commit comments

Comments
 (0)