Skip to content

Commit 5d83ab4

Browse files
committed
added tests for initial state
1 parent 56456c9 commit 5d83ab4

File tree

1 file changed

+54
-0
lines changed

1 file changed

+54
-0
lines changed

test/sample.jl

+54
Original file line numberDiff line numberDiff line change
@@ -658,4 +658,58 @@
658658
)
659659
@test it_array == collect(1:size(chain, 1))
660660
end
661+
662+
@testset "Providing initial state" begin
663+
function record_state(rng, model, sampler, sample, state, i; states_channel, kwargs...)
664+
put!(states_channel, state)
665+
end
666+
667+
initial_state = 10
668+
669+
@testset "sample" begin
670+
n = 10
671+
states_channel = Channel{Int}(n)
672+
chain = sample(
673+
MyModel(), MySampler(), n;
674+
initial_state=initial_state,
675+
callback=record_state,
676+
states_channel=states_channel
677+
)
678+
679+
# Extract the states.
680+
states = [take!(states_channel) for _ in 1:n]
681+
@test length(states) == n
682+
for i in 1:n
683+
@test states[i] == initial_state + i
684+
end
685+
end
686+
687+
@testset "sample with $mode" for mode in [
688+
MCMCSerial(),
689+
MCMCThreads(),
690+
MCMCDistributed(),
691+
]
692+
nchains = 4
693+
initial_state = 10
694+
states_channel = if mode === MCMCDistributed()
695+
# Need to use `RemoteChannel` for this.
696+
RemoteChannel(() -> Channel{Int}(nchains))
697+
else
698+
Channel{Int}(nchains)
699+
end
700+
chain = sample(
701+
MyModel(), MySampler(), mode, 1, nchains;
702+
initial_state=FillArrays.Fill(initial_state, nchains),
703+
callback=record_state,
704+
states_channel=states_channel
705+
)
706+
707+
# Extract the states.
708+
states = [take!(states_channel) for _ in 1:nchains]
709+
@test length(states) == nchains
710+
for i = 1:nchains
711+
@test states[i] == initial_state + 1
712+
end
713+
end
714+
end
661715
end

0 commit comments

Comments
 (0)