|
366 | 366 | @test all(length(x) == N for x in chains) |
367 | 367 |
|
368 | 368 | Random.seed!(1234) |
369 | | - chains = sample(MyModel(), MySampler(), MCMCSerial(), N, 1000; chain_type=MyChain, progress=false) |
| 369 | + chains = sample( |
| 370 | + MyModel(), |
| 371 | + MySampler(), |
| 372 | + MCMCSerial(), |
| 373 | + N, |
| 374 | + 1000; |
| 375 | + chain_type=MyChain, |
| 376 | + progress=false, |
| 377 | + ) |
370 | 378 |
|
371 | 379 | # Test output type and size. |
372 | 380 | @test chains isa Vector{<:MyChain} |
|
382 | 390 |
|
383 | 391 | # Test reproducibility. |
384 | 392 | Random.seed!(1234) |
385 | | - chains2 = sample(MyModel(), MySampler(), MCMCSerial(), N, 1000; chain_type=MyChain, progress=false) |
| 393 | + chains2 = sample( |
| 394 | + MyModel(), |
| 395 | + MySampler(), |
| 396 | + MCMCSerial(), |
| 397 | + N, |
| 398 | + 1000; |
| 399 | + chain_type=MyChain, |
| 400 | + progress=false, |
| 401 | + ) |
386 | 402 | @test all(ismissing(c.as[1]) for c in chains2) |
387 | 403 | @test all(c1.as[i] == c2.as[i] for (c1, c2) in zip(chains, chains2), i in 2:N) |
388 | 404 | @test all(c1.bs[i] == c2.bs[i] for (c1, c2) in zip(chains, chains2), i in 1:N) |
|
660 | 676 | end |
661 | 677 |
|
662 | 678 | @testset "Providing initial state" begin |
663 | | - function record_state(rng, model, sampler, sample, state, i; states_channel, kwargs...) |
664 | | - put!(states_channel, state) |
| 679 | + function record_state( |
| 680 | + rng, model, sampler, sample, state, i; states_channel, kwargs... |
| 681 | + ) |
| 682 | + return put!(states_channel, state) |
665 | 683 | end |
666 | 684 |
|
667 | 685 | initial_state = 10 |
|
670 | 688 | n = 10 |
671 | 689 | states_channel = Channel{Int}(n) |
672 | 690 | chain = sample( |
673 | | - MyModel(), MySampler(), n; |
| 691 | + MyModel(), |
| 692 | + MySampler(), |
| 693 | + n; |
674 | 694 | initial_state=initial_state, |
675 | 695 | callback=record_state, |
676 | | - states_channel=states_channel |
| 696 | + states_channel=states_channel, |
677 | 697 | ) |
678 | 698 |
|
679 | 699 | # Extract the states. |
|
684 | 704 | end |
685 | 705 | end |
686 | 706 |
|
687 | | - @testset "sample with $mode" for mode in [ |
688 | | - MCMCSerial(), |
689 | | - MCMCThreads(), |
690 | | - MCMCDistributed(), |
691 | | - ] |
| 707 | + @testset "sample with $mode" for mode in |
| 708 | + [MCMCSerial(), MCMCThreads(), MCMCDistributed()] |
692 | 709 | nchains = 4 |
693 | 710 | initial_state = 10 |
694 | 711 | states_channel = if mode === MCMCDistributed() |
|
698 | 715 | Channel{Int}(nchains) |
699 | 716 | end |
700 | 717 | chain = sample( |
701 | | - MyModel(), MySampler(), mode, 1, nchains; |
| 718 | + MyModel(), |
| 719 | + MySampler(), |
| 720 | + mode, |
| 721 | + 1, |
| 722 | + nchains; |
702 | 723 | initial_state=FillArrays.Fill(initial_state, nchains), |
703 | 724 | callback=record_state, |
704 | | - states_channel=states_channel |
| 725 | + states_channel=states_channel, |
705 | 726 | ) |
706 | 727 |
|
707 | 728 | # Extract the states. |
708 | 729 | states = [take!(states_channel) for _ in 1:nchains] |
709 | 730 | @test length(states) == nchains |
710 | | - for i = 1:nchains |
| 731 | + for i in 1:nchains |
711 | 732 | @test states[i] == initial_state + 1 |
712 | 733 | end |
713 | 734 | end |
|
0 commit comments