Skip to content

Commit 3ed5314

Browse files
Apply suggestions from code review
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 5d83ab4 commit 3ed5314

File tree

2 files changed

+43
-18
lines changed

2 files changed

+43
-18
lines changed

src/sample.jl

+8-4
Original file line numberDiff line numberDiff line change
@@ -432,8 +432,10 @@ function mcmcsample(
432432
check_initial_params(initial_params, nchains)
433433
check_initial_state(initial_state, nchains)
434434

435-
_initial_params = initial_params === nothing ? FillArrays.Fill(nothing, nchains) : initial_params
436-
_initial_state = initial_state === nothing ? FillArrays.Fill(nothing, nchains) : initial_state
435+
_initial_params =
436+
initial_params === nothing ? FillArrays.Fill(nothing, nchains) : initial_params
437+
_initial_state =
438+
initial_state === nothing ? FillArrays.Fill(nothing, nchains) : initial_state
437439

438440
# Create a seed for each chain using the provided random number generator.
439441
seeds = rand(rng, UInt, nchains)
@@ -528,8 +530,10 @@ function mcmcsample(
528530
check_initial_params(initial_params, nchains)
529531
check_initial_state(initial_state, nchains)
530532

531-
_initial_params = initial_params === nothing ? FillArrays.Fill(nothing, nchains) : initial_params
532-
_initial_state = initial_state === nothing ? FillArrays.Fill(nothing, nchains) : initial_state
533+
_initial_params =
534+
initial_params === nothing ? FillArrays.Fill(nothing, nchains) : initial_params
535+
_initial_state =
536+
initial_state === nothing ? FillArrays.Fill(nothing, nchains) : initial_state
533537

534538
# Create a seed for each chain using the provided random number generator.
535539
seeds = rand(rng, UInt, nchains)

test/sample.jl

+35-14
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,15 @@
366366
@test all(length(x) == N for x in chains)
367367

368368
Random.seed!(1234)
369-
chains = sample(MyModel(), MySampler(), MCMCSerial(), N, 1000; chain_type=MyChain, progress=false)
369+
chains = sample(
370+
MyModel(),
371+
MySampler(),
372+
MCMCSerial(),
373+
N,
374+
1000;
375+
chain_type=MyChain,
376+
progress=false,
377+
)
370378

371379
# Test output type and size.
372380
@test chains isa Vector{<:MyChain}
@@ -382,7 +390,15 @@
382390

383391
# Test reproducibility.
384392
Random.seed!(1234)
385-
chains2 = sample(MyModel(), MySampler(), MCMCSerial(), N, 1000; chain_type=MyChain, progress=false)
393+
chains2 = sample(
394+
MyModel(),
395+
MySampler(),
396+
MCMCSerial(),
397+
N,
398+
1000;
399+
chain_type=MyChain,
400+
progress=false,
401+
)
386402
@test all(ismissing(c.as[1]) for c in chains2)
387403
@test all(c1.as[i] == c2.as[i] for (c1, c2) in zip(chains, chains2), i in 2:N)
388404
@test all(c1.bs[i] == c2.bs[i] for (c1, c2) in zip(chains, chains2), i in 1:N)
@@ -660,8 +676,10 @@
660676
end
661677

662678
@testset "Providing initial state" begin
663-
function record_state(rng, model, sampler, sample, state, i; states_channel, kwargs...)
664-
put!(states_channel, state)
679+
function record_state(
680+
rng, model, sampler, sample, state, i; states_channel, kwargs...
681+
)
682+
return put!(states_channel, state)
665683
end
666684

667685
initial_state = 10
@@ -670,10 +688,12 @@
670688
n = 10
671689
states_channel = Channel{Int}(n)
672690
chain = sample(
673-
MyModel(), MySampler(), n;
691+
MyModel(),
692+
MySampler(),
693+
n;
674694
initial_state=initial_state,
675695
callback=record_state,
676-
states_channel=states_channel
696+
states_channel=states_channel,
677697
)
678698

679699
# Extract the states.
@@ -684,11 +704,8 @@
684704
end
685705
end
686706

687-
@testset "sample with $mode" for mode in [
688-
MCMCSerial(),
689-
MCMCThreads(),
690-
MCMCDistributed(),
691-
]
707+
@testset "sample with $mode" for mode in
708+
[MCMCSerial(), MCMCThreads(), MCMCDistributed()]
692709
nchains = 4
693710
initial_state = 10
694711
states_channel = if mode === MCMCDistributed()
@@ -698,16 +715,20 @@
698715
Channel{Int}(nchains)
699716
end
700717
chain = sample(
701-
MyModel(), MySampler(), mode, 1, nchains;
718+
MyModel(),
719+
MySampler(),
720+
mode,
721+
1,
722+
nchains;
702723
initial_state=FillArrays.Fill(initial_state, nchains),
703724
callback=record_state,
704-
states_channel=states_channel
725+
states_channel=states_channel,
705726
)
706727

707728
# Extract the states.
708729
states = [take!(states_channel) for _ in 1:nchains]
709730
@test length(states) == nchains
710-
for i = 1:nchains
731+
for i in 1:nchains
711732
@test states[i] == initial_state + 1
712733
end
713734
end

0 commit comments

Comments
 (0)