Skip to content

Commit 5cc1585

Browse files
committed
Fix init_params=nothing (#401)
Fixes the test errors in TuringLang/Turing.jl#1799. Co-authored-by: David Widmann <[email protected]>
1 parent 748b191 commit 5cc1585

File tree

3 files changed

+30
-5
lines changed

3 files changed

+30
-5
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.19.0"
3+
version = "0.19.1"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/sampler.jl

+9-4
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,12 @@ end
6969

7070
# initial step: general interface for resuming and
7171
function AbstractMCMC.step(
72-
rng::Random.AbstractRNG, model::Model, spl::Sampler; resume_from=nothing, kwargs...
72+
rng::Random.AbstractRNG,
73+
model::Model,
74+
spl::Sampler;
75+
resume_from=nothing,
76+
init_params=nothing,
77+
kwargs...,
7378
)
7479
if resume_from !== nothing
7580
state = loadstate(resume_from)
@@ -81,8 +86,8 @@ function AbstractMCMC.step(
8186
vi = VarInfo(rng, model, _spl)
8287

8388
# Update the parameters if provided.
84-
if haskey(kwargs, :init_params)
85-
vi = initialize_parameters!!(vi, kwargs[:init_params], spl)
89+
if init_params !== nothing
90+
vi = initialize_parameters!!(vi, init_params, spl)
8691

8792
# Update joint log probability.
8893
# TODO: fix properly by using sampler and evaluation contexts
@@ -96,7 +101,7 @@ function AbstractMCMC.step(
96101
end
97102
end
98103

99-
return initialstep(rng, model, spl, vi; kwargs...)
104+
return initialstep(rng, model, spl, vi; init_params=init_params, kwargs...)
100105
end
101106

102107
"""

test/sampler.jl

+20
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,26 @@
125125
@test !ismissing(c[1].metadata.s.vals[1])
126126
@test c[1].metadata.m.vals == [-1]
127127
end
128+
129+
# specify `init_params=nothing`
130+
Random.seed!(1234)
131+
chain1 = sample(model, sampler, 1; progress=false)
132+
Random.seed!(1234)
133+
chain2 = sample(model, sampler, 1; init_params=nothing, progress=false)
134+
@test chain1[1].metadata.m.vals == chain2[1].metadata.m.vals
135+
@test chain1[1].metadata.s.vals == chain2[1].metadata.s.vals
136+
137+
# parallel sampling
138+
Random.seed!(1234)
139+
chains1 = sample(model, sampler, MCMCThreads(), 1, 10; progress=false)
140+
Random.seed!(1234)
141+
chains2 = sample(
142+
model, sampler, MCMCThreads(), 1, 10; init_params=nothing, progress=false
143+
)
144+
for (c1, c2) in zip(chains1, chains2)
145+
@test c1[1].metadata.m.vals == c2[1].metadata.m.vals
146+
@test c1[1].metadata.s.vals == c2[1].metadata.s.vals
147+
end
128148
end
129149
end
130150
end

0 commit comments

Comments
 (0)