Skip to content

Commit 33487da

Browse files
devmotiongithub-actions[bot]torfjelde
authored
Support log density functions as models (#113)
* Update sample.jl * Update sample.jl * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update api.md * Update stepper.jl * Update transducer.jl * Update api.md * Update src/stepper.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/transducer.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update Project.toml * Update src/sample.jl Co-authored-by: Tor Erlend Fjelde <[email protected]> * Reorganize fallbacks * Add tests * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update Project.toml * Define utilities on all workers * Update test/sample.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Tor Erlend Fjelde <[email protected]>
1 parent 2d31f09 commit 33487da

10 files changed

+296
-34
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
33
keywords = ["markov chain monte carlo", "probablistic programming"]
44
license = "MIT"
55
desc = "A lightweight interface for common MCMC methods."
6-
version = "4.3.0"
6+
version = "4.4.0"
77

88
[deps]
99
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"

docs/src/api.md

+25-1
Original file line numberDiff line numberDiff line change
@@ -2,23 +2,39 @@
22

33
AbstractMCMC defines an interface for sampling Markov chains.
44

5+
## Model
6+
7+
```@docs
8+
AbstractMCMC.AbstractModel
9+
AbstractMCMC.LogDensityModel
10+
```
11+
12+
## Sampler
13+
14+
```@docs
15+
AbstractMCMC.AbstractSampler
16+
```
17+
518
## Sampling a single chain
619

720
```@docs
8-
AbstractMCMC.sample(::AbstractRNG, ::AbstractMCMC.AbstractModel, ::AbstractMCMC.AbstractSampler, ::Integer)
921
AbstractMCMC.sample(::AbstractRNG, ::AbstractMCMC.AbstractModel, ::AbstractMCMC.AbstractSampler, ::Any)
22+
AbstractMCMC.sample(::AbstractRNG, ::Any, ::AbstractMCMC.AbstractSampler, ::Any)
23+
1024
```
1125

1226
### Iterator
1327

1428
```@docs
1529
AbstractMCMC.steps(::AbstractRNG, ::AbstractMCMC.AbstractModel, ::AbstractMCMC.AbstractSampler)
30+
AbstractMCMC.steps(::AbstractRNG, ::Any, ::AbstractMCMC.AbstractSampler)
1631
```
1732

1833
### Transducer
1934

2035
```@docs
2136
AbstractMCMC.Sample(::AbstractRNG, ::AbstractMCMC.AbstractModel, ::AbstractMCMC.AbstractSampler)
37+
AbstractMCMC.Sample(::AbstractRNG, ::Any, ::AbstractMCMC.AbstractSampler)
2238
```
2339

2440
## Sampling multiple chains in parallel
@@ -32,6 +48,14 @@ AbstractMCMC.sample(
3248
::Integer,
3349
::Integer,
3450
)
51+
AbstractMCMC.sample(
52+
::AbstractRNG,
53+
::Any,
54+
::AbstractMCMC.AbstractSampler,
55+
::AbstractMCMC.AbstractMCMCEnsemble,
56+
::Integer,
57+
::Integer,
58+
)
3559
```
3660

3761
Two algorithms are provided for parallel sampling with multiple threads and multiple processes, and one allows for the user to sample multiple chains in serial (no parallelization):

src/logdensityproblems.jl

+92
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,95 @@ struct LogDensityModel{L} <: AbstractModel
2525
end
2626

2727
LogDensityModel(logdensity::L) where {L} = LogDensityModel{L}(logdensity)
28+
29+
# Fallbacks: Wrap log density function in a model
30+
"""
31+
sample(
32+
rng::Random.AbstractRNG=Random.default_rng(),
33+
logdensity,
34+
sampler::AbstractSampler,
35+
N_or_isdone;
36+
kwargs...,
37+
)
38+
39+
Wrap the `logdensity` function in a [`LogDensityModel`](@ref), and call `sample` with the resulting model instead of `logdensity`.
40+
41+
The `logdensity` function has to support the [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) interface.
42+
"""
43+
function StatsBase.sample(
44+
rng::Random.AbstractRNG, logdensity, sampler::AbstractSampler, N_or_isdone; kwargs...
45+
)
46+
return StatsBase.sample(rng, _model(logdensity), sampler, N_or_isdone; kwargs...)
47+
end
48+
49+
"""
50+
sample(
51+
rng::Random.AbstractRNG=Random.default_rng(),
52+
logdensity,
53+
sampler::AbstractSampler,
54+
parallel::AbstractMCMCEnsemble,
55+
N::Integer,
56+
nchains::Integer;
57+
kwargs...,
58+
)
59+
60+
Wrap the `logdensity` function in a [`LogDensityModel`](@ref), and call `sample` with the resulting model instead of `logdensity`.
61+
62+
The `logdensity` function has to support the [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) interface.
63+
"""
64+
function StatsBase.sample(
65+
rng::Random.AbstractRNG,
66+
logdensity,
67+
sampler::AbstractSampler,
68+
parallel::AbstractMCMCEnsemble,
69+
N::Integer,
70+
nchains::Integer;
71+
kwargs...,
72+
)
73+
return StatsBase.sample(
74+
rng, _model(logdensity), sampler, parallel, N, nchains; kwargs...
75+
)
76+
end
77+
78+
"""
79+
steps(
80+
rng::Random.AbstractRNG=Random.default_rng(),
81+
logdensity,
82+
sampler::AbstractSampler;
83+
kwargs...,
84+
)
85+
86+
Wrap the `logdensity` function in a [`LogDensityModel`](@ref), and call `steps` with the resulting model instead of `logdensity`.
87+
88+
The `logdensity` function has to support the [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) interface.
89+
"""
90+
function steps(rng::Random.AbstractRNG, logdensity, sampler::AbstractSampler; kwargs...)
91+
return steps(rng, _model(logdensity), sampler; kwargs...)
92+
end
93+
94+
"""
95+
Sample(
96+
rng::Random.AbstractRNG=Random.default_rng(),
97+
logdensity,
98+
sampler::AbstractSampler;
99+
kwargs...,
100+
)
101+
102+
Wrap the `logdensity` function in a [`LogDensityModel`](@ref), and call `Sample` with the resulting model instead of `logdensity`.
103+
104+
The `logdensity` function has to support the [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) interface.
105+
"""
106+
function Sample(rng::Random.AbstractRNG, logdensity, sampler::AbstractSampler; kwargs...)
107+
return Sample(rng, _model(logdensity), sampler; kwargs...)
108+
end
109+
110+
function _model(logdensity)
111+
if LogDensityProblems.capabilities(logdensity) === nothing
112+
throw(
113+
ArgumentError(
114+
"the log density function does not support the LogDensityProblems.jl interface. Please implement the interface or provide a model of type `AbstractMCMC.AbstractModel`",
115+
),
116+
)
117+
end
118+
return LogDensityModel(logdensity)
119+
end

src/sample.jl

+29-24
Original file line numberDiff line numberDiff line change
@@ -12,32 +12,29 @@ function setprogress!(progress::Bool)
1212
return progress
1313
end
1414

15-
function StatsBase.sample(model::AbstractModel, sampler::AbstractSampler, arg; kwargs...)
16-
return StatsBase.sample(Random.default_rng(), model, sampler, arg; kwargs...)
17-
end
18-
19-
"""
20-
sample([rng, ]model, sampler, N; kwargs...)
21-
22-
Return `N` samples from the `model` with the Markov chain Monte Carlo `sampler`.
23-
"""
2415
function StatsBase.sample(
25-
rng::Random.AbstractRNG,
26-
model::AbstractModel,
27-
sampler::AbstractSampler,
28-
N::Integer;
29-
kwargs...,
16+
model_or_logdensity, sampler::AbstractSampler, N_or_isdone; kwargs...
3017
)
31-
return mcmcsample(rng, model, sampler, N; kwargs...)
18+
return StatsBase.sample(
19+
Random.default_rng(), model_or_logdensity, sampler, N_or_isdone; kwargs...
20+
)
3221
end
3322

3423
"""
35-
sample([rng, ]model, sampler, isdone; kwargs...)
24+
sample(
25+
rng::Random.AbstractRNG=Random.default_rng(),
26+
model::AbstractModel,
27+
sampler::AbstractSampler,
28+
N_or_isdone;
29+
kwargs...,
30+
)
31+
32+
Sample from the `model` with the Markov chain Monte Carlo `sampler` and return the samples.
3633
37-
Sample from the `model` with the Markov chain Monte Carlo `sampler` until a
38-
convergence criterion `isdone` returns `true`, and return the samples.
34+
If `N_or_isdone` is an `Integer`, exactly `N_or_isdone` samples are returned.
3935
40-
The function `isdone` has the signature
36+
Otherwise, sampling is performed until a convergence criterion `N_or_isdone` returns `true`.
37+
The convergence criterion has to be a function with the signature
4138
```julia
4239
isdone(rng, model, sampler, samples, state, iteration; kwargs...)
4340
```
@@ -48,27 +45,35 @@ function StatsBase.sample(
4845
rng::Random.AbstractRNG,
4946
model::AbstractModel,
5047
sampler::AbstractSampler,
51-
isdone;
48+
N_or_isdone;
5249
kwargs...,
5350
)
54-
return mcmcsample(rng, model, sampler, isdone; kwargs...)
51+
return mcmcsample(rng, model, sampler, N_or_isdone; kwargs...)
5552
end
5653

5754
function StatsBase.sample(
58-
model::AbstractModel,
55+
model_or_logdensity,
5956
sampler::AbstractSampler,
6057
parallel::AbstractMCMCEnsemble,
6158
N::Integer,
6259
nchains::Integer;
6360
kwargs...,
6461
)
6562
return StatsBase.sample(
66-
Random.default_rng(), model, sampler, parallel, N, nchains; kwargs...
63+
Random.default_rng(), model_or_logdensity, sampler, parallel, N, nchains; kwargs...
6764
)
6865
end
6966

7067
"""
71-
sample([rng, ]model, sampler, parallel, N, nchains; kwargs...)
68+
sample(
69+
rng::Random.AbstractRNG=Random.default_rng(),
70+
model::AbstractModel,
71+
sampler::AbstractSampler,
72+
parallel::AbstractMCMCEnsemble,
73+
N::Integer,
74+
nchains::Integer;
75+
kwargs...,
76+
)
7277
7378
Sample `nchains` Monte Carlo Markov chains from the `model` with the `sampler` in parallel
7479
using the `parallel` algorithm, and combine them into a single chain.

src/stepper.jl

+8-3
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,17 @@ end
4141
Base.IteratorSize(::Type{<:Stepper}) = Base.IsInfinite()
4242
Base.IteratorEltype(::Type{<:Stepper}) = Base.EltypeUnknown()
4343

44-
function steps(model::AbstractModel, sampler::AbstractSampler; kwargs...)
45-
return steps(Random.default_rng(), model, sampler; kwargs...)
44+
function steps(model_or_logdensity, sampler::AbstractSampler; kwargs...)
45+
return steps(Random.default_rng(), model_or_logdensity, sampler; kwargs...)
4646
end
4747

4848
"""
49-
steps([rng, ]model, sampler; kwargs...)
49+
steps(
50+
rng::Random.AbstractRNG=Random.default_rng(),
51+
model::AbstractModel,
52+
sampler::AbstractSampler;
53+
kwargs...,
54+
)
5055
5156
Create an iterator that returns samples from the `model` with the Markov chain Monte Carlo
5257
`sampler`.

src/transducer.jl

+8-3
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,17 @@ struct Sample{A<:Random.AbstractRNG,M<:AbstractModel,S<:AbstractSampler,K} <:
66
kwargs::K
77
end
88

9-
function Sample(model::AbstractModel, sampler::AbstractSampler; kwargs...)
10-
return Sample(Random.default_rng(), model, sampler; kwargs...)
9+
function Sample(model_or_logdensity, sampler::AbstractSampler; kwargs...)
10+
return Sample(Random.default_rng(), model_or_logdensity, sampler; kwargs...)
1111
end
1212

1313
"""
14-
Sample([rng, ]model, sampler; kwargs...)
14+
Sample(
15+
rng::Random.AbstractRNG=Random.default_rng(),
16+
model::AbstractModel,
17+
sampler::AbstractSampler;
18+
kwargs...,
19+
)
1520
1621
Create a transducer that returns samples from the `model` with the Markov chain Monte Carlo
1722
`sampler`.

test/logdensityproblems.jl

+90
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
@testset "logdensityproblems.jl" begin
2+
# Add worker processes.
3+
# Memory requirements on Windows are ~4x larger than on Linux, hence number of processes is reduced
4+
# See, e.g., https://github.com/JuliaLang/julia/issues/40766 and https://github.com/JuliaLang/Pkg.jl/pull/2366
5+
pids = addprocs(Sys.iswindows() ? div(Sys.CPU_THREADS::Int, 2) : Sys.CPU_THREADS::Int)
6+
7+
# Load all required packages (`utils.jl` needs LogDensityProblems, Logging, and Random).
8+
@everywhere begin
9+
using AbstractMCMC
10+
using AbstractMCMC: sample
11+
using LogDensityProblems
12+
13+
using Logging
14+
using Random
15+
include("utils.jl")
16+
end
17+
18+
@testset "LogDensityModel" begin
19+
ℓ = MyLogDensity(10)
20+
model = @inferred AbstractMCMC.LogDensityModel(ℓ)
21+
@test model isa AbstractMCMC.LogDensityModel{MyLogDensity}
22+
@test model.logdensity === ℓ
23+
24+
@test_throws ArgumentError AbstractMCMC.LogDensityModel(mylogdensity)
25+
end
26+
27+
@testset "fallback for log densities" begin
28+
# Sample with log density
29+
dim = 10
30+
ℓ = MyLogDensity(dim)
31+
Random.seed!(1234)
32+
N = 1_000
33+
samples = sample(ℓ, MySampler(), N)
34+
35+
# Samples are of the correct dimension and log density values are correct
36+
@test length(samples) == N
37+
@test all(length(x.a) == dim for x in samples)
38+
@test all(x.b ≈ LogDensityProblems.logdensity(ℓ, x.a) for x in samples)
39+
40+
# Same chain as if LogDensityModel is used explicitly
41+
Random.seed!(1234)
42+
samples2 = sample(AbstractMCMC.LogDensityModel(ℓ), MySampler(), N)
43+
@test length(samples2) == N
44+
@test all(x.a == y.a && x.b == y.b for (x, y) in zip(samples, samples2))
45+
46+
# Same chain if sampling is performed with convergence criterion
47+
Random.seed!(1234)
48+
isdone(rng, model, sampler, state, samples, iteration; kwargs...) = iteration > N
49+
samples3 = sample(ℓ, MySampler(), isdone)
50+
@test length(samples3) == N
51+
@test all(x.a == y.a && x.b == y.b for (x, y) in zip(samples, samples3))
52+
53+
# Same chain if sampling is performed with iterator
54+
Random.seed!(1234)
55+
samples4 = collect(Iterators.take(AbstractMCMC.steps(ℓ, MySampler()), N))
56+
@test length(samples4) == N
57+
@test all(x.a == y.a && x.b == y.b for (x, y) in zip(samples, samples4))
58+
59+
# Same chain if sampling is performed with transducer
60+
Random.seed!(1234)
61+
xf = AbstractMCMC.Sample(ℓ, MySampler())
62+
samples5 = collect(xf(1:N))
63+
@test length(samples5) == N
64+
@test all(x.a == y.a && x.b == y.b for (x, y) in zip(samples, samples5))
65+
66+
# Parallel sampling
67+
for alg in (MCMCSerial(), MCMCDistributed(), MCMCThreads())
68+
chains = sample(ℓ, MySampler(), alg, N, 2)
69+
@test length(chains) == 2
70+
samples = vcat(chains[1], chains[2])
71+
@test length(samples) == 2 * N
72+
@test all(length(x.a) == dim for x in samples)
73+
@test all(x.b ≈ LogDensityProblems.logdensity(ℓ, x.a) for x in samples)
74+
end
75+
76+
# Log density has to satisfy the LogDensityProblems interface
77+
@test_throws ArgumentError sample(mylogdensity, MySampler(), N)
78+
@test_throws ArgumentError sample(mylogdensity, MySampler(), isdone)
79+
@test_throws ArgumentError sample(mylogdensity, MySampler(), MCMCSerial(), N, 2)
80+
@test_throws ArgumentError sample(mylogdensity, MySampler(), MCMCThreads(), N, 2)
81+
@test_throws ArgumentError sample(
82+
mylogdensity, MySampler(), MCMCDistributed(), N, 2
83+
)
84+
@test_throws ArgumentError AbstractMCMC.steps(mylogdensity, MySampler())
85+
@test_throws ArgumentError AbstractMCMC.Sample(mylogdensity, MySampler())
86+
end
87+
88+
# Remove workers
89+
rmprocs(pids...)
90+
end

test/runtests.jl

+2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ using AbstractMCMC
22
using Atom.Progress: JunoProgressLogger
33
using ConsoleProgressMonitor: ProgressLogger
44
using IJulia
5+
using LogDensityProblems
56
using LoggingExtras: TeeLogger, EarlyFilteredLogger
67
using TerminalLoggers: TerminalLogger
78
using Transducers
@@ -22,4 +23,5 @@ include("utils.jl")
2223
include("sample.jl")
2324
include("stepper.jl")
2425
include("transducer.jl")
26+
include("logdensityproblems.jl")
2527
end

0 commit comments

Comments
 (0)