|
366 | 366 | @test all(length(x) == N for x in chains)
|
367 | 367 |
|
368 | 368 | Random.seed!(1234)
|
369 |
| - chains = sample(MyModel(), MySampler(), MCMCSerial(), N, 1000; chain_type=MyChain, progress=false) |
| 369 | + chains = sample( |
| 370 | + MyModel(), |
| 371 | + MySampler(), |
| 372 | + MCMCSerial(), |
| 373 | + N, |
| 374 | + 1000; |
| 375 | + chain_type=MyChain, |
| 376 | + progress=false, |
| 377 | + ) |
370 | 378 |
|
371 | 379 | # Test output type and size.
|
372 | 380 | @test chains isa Vector{<:MyChain}
|
|
382 | 390 |
|
383 | 391 | # Test reproducibility.
|
384 | 392 | Random.seed!(1234)
|
385 |
| - chains2 = sample(MyModel(), MySampler(), MCMCSerial(), N, 1000; chain_type=MyChain, progress=false) |
| 393 | + chains2 = sample( |
| 394 | + MyModel(), |
| 395 | + MySampler(), |
| 396 | + MCMCSerial(), |
| 397 | + N, |
| 398 | + 1000; |
| 399 | + chain_type=MyChain, |
| 400 | + progress=false, |
| 401 | + ) |
386 | 402 | @test all(ismissing(c.as[1]) for c in chains2)
|
387 | 403 | @test all(c1.as[i] == c2.as[i] for (c1, c2) in zip(chains, chains2), i in 2:N)
|
388 | 404 | @test all(c1.bs[i] == c2.bs[i] for (c1, c2) in zip(chains, chains2), i in 1:N)
|
|
660 | 676 | end
|
661 | 677 |
|
662 | 678 | @testset "Providing initial state" begin
|
663 |
| - function record_state(rng, model, sampler, sample, state, i; states_channel, kwargs...) |
664 |
| - put!(states_channel, state) |
| 679 | + function record_state( |
| 680 | + rng, model, sampler, sample, state, i; states_channel, kwargs... |
| 681 | + ) |
| 682 | + return put!(states_channel, state) |
665 | 683 | end
|
666 | 684 |
|
667 | 685 | initial_state = 10
|
|
670 | 688 | n = 10
|
671 | 689 | states_channel = Channel{Int}(n)
|
672 | 690 | chain = sample(
|
673 |
| - MyModel(), MySampler(), n; |
| 691 | + MyModel(), |
| 692 | + MySampler(), |
| 693 | + n; |
674 | 694 | initial_state=initial_state,
|
675 | 695 | callback=record_state,
|
676 |
| - states_channel=states_channel |
| 696 | + states_channel=states_channel, |
677 | 697 | )
|
678 | 698 |
|
679 | 699 | # Extract the states.
|
|
684 | 704 | end
|
685 | 705 | end
|
686 | 706 |
|
687 |
| - @testset "sample with $mode" for mode in [ |
688 |
| - MCMCSerial(), |
689 |
| - MCMCThreads(), |
690 |
| - MCMCDistributed(), |
691 |
| - ] |
| 707 | + @testset "sample with $mode" for mode in |
| 708 | + [MCMCSerial(), MCMCThreads(), MCMCDistributed()] |
692 | 709 | nchains = 4
|
693 | 710 | initial_state = 10
|
694 | 711 | states_channel = if mode === MCMCDistributed()
|
|
698 | 715 | Channel{Int}(nchains)
|
699 | 716 | end
|
700 | 717 | chain = sample(
|
701 |
| - MyModel(), MySampler(), mode, 1, nchains; |
| 718 | + MyModel(), |
| 719 | + MySampler(), |
| 720 | + mode, |
| 721 | + 1, |
| 722 | + nchains; |
702 | 723 | initial_state=FillArrays.Fill(initial_state, nchains),
|
703 | 724 | callback=record_state,
|
704 |
| - states_channel=states_channel |
| 725 | + states_channel=states_channel, |
705 | 726 | )
|
706 | 727 |
|
707 | 728 | # Extract the states.
|
708 | 729 | states = [take!(states_channel) for _ in 1:nchains]
|
709 | 730 | @test length(states) == nchains
|
710 |
| - for i = 1:nchains |
| 731 | + for i in 1:nchains |
711 | 732 | @test states[i] == initial_state + 1
|
712 | 733 | end
|
713 | 734 | end
|
|
0 commit comments