|
658 | 658 | )
|
659 | 659 | @test it_array == collect(1:size(chain, 1))
|
660 | 660 | end
|
| 661 | + |
| 662 | + @testset "Providing initial state" begin |
| 663 | + function record_state(rng, model, sampler, sample, state, i; states_channel, kwargs...) |
| 664 | + put!(states_channel, state) |
| 665 | + end |
| 666 | + |
| 667 | + initial_state = 10 |
| 668 | + |
| 669 | + @testset "sample" begin |
| 670 | + n = 10 |
| 671 | + states_channel = Channel{Int}(n) |
| 672 | + chain = sample( |
| 673 | + MyModel(), MySampler(), n; |
| 674 | + initial_state=initial_state, |
| 675 | + callback=record_state, |
| 676 | + states_channel=states_channel |
| 677 | + ) |
| 678 | + |
| 679 | + # Extract the states. |
| 680 | + states = [take!(states_channel) for _ in 1:n] |
| 681 | + @test length(states) == n |
| 682 | + for i in 1:n |
| 683 | + @test states[i] == initial_state + i |
| 684 | + end |
| 685 | + end |
| 686 | + |
| 687 | + @testset "sample with $mode" for mode in [ |
| 688 | + MCMCSerial(), |
| 689 | + MCMCThreads(), |
| 690 | + MCMCDistributed(), |
| 691 | + ] |
| 692 | + nchains = 4 |
| 693 | + initial_state = 10 |
| 694 | + states_channel = if mode === MCMCDistributed() |
| 695 | + # Need to use `RemoteChannel` for this. |
| 696 | + RemoteChannel(() -> Channel{Int}(nchains)) |
| 697 | + else |
| 698 | + Channel{Int}(nchains) |
| 699 | + end |
| 700 | + chain = sample( |
| 701 | + MyModel(), MySampler(), mode, 1, nchains; |
| 702 | + initial_state=FillArrays.Fill(initial_state, nchains), |
| 703 | + callback=record_state, |
| 704 | + states_channel=states_channel |
| 705 | + ) |
| 706 | + |
| 707 | + # Extract the states. |
| 708 | + states = [take!(states_channel) for _ in 1:nchains] |
| 709 | + @test length(states) == nchains |
| 710 | + for i = 1:nchains |
| 711 | + @test states[i] == initial_state + 1 |
| 712 | + end |
| 713 | + end |
| 714 | + end |
661 | 715 | end
|
0 commit comments