@@ -528,11 +528,14 @@ function mcmcsample(
528
528
check_initial_params (initial_params, nchains)
529
529
check_initial_state (initial_state, nchains)
530
530
531
+ _initial_params = initial_params === nothing ? FillArrays. Fill (nothing , nchains) : initial_params
532
+ _initial_state = initial_state === nothing ? FillArrays. Fill (nothing , nchains) : initial_state
533
+
531
534
# Create a seed for each chain using the provided random number generator.
532
535
seeds = rand (rng, UInt, nchains)
533
536
534
537
# Sample the chains.
535
- function sample_chain (i, seed, initial_params= nothing )
538
+ function sample_chain (i, seed, initial_params, initial_state )
536
539
# Seed a new random number generator with the pre-made seed.
537
540
Random. seed! (rng, seed)
538
541
@@ -544,15 +547,12 @@ function mcmcsample(
544
547
N;
545
548
progressname= string (progressname, " (Chain " , i, " of " , nchains, " )" ),
546
549
initial_params= initial_params,
550
+ initial_state= initial_state,
547
551
kwargs... ,
548
552
)
549
553
end
550
554
551
- chains = if initial_params === nothing
552
- map (sample_chain, 1 : nchains, seeds)
553
- else
554
- map (sample_chain, 1 : nchains, seeds, initial_params)
555
- end
555
+ chains = map (sample_chain, 1 : nchains, seeds, _initial_params, _initial_state)
556
556
557
557
# Concatenate the chains together.
558
558
return chainsstack (tighten_eltype (chains))
0 commit comments