Skip to content

Commit d29f810

Browse files
committed
Merge remote-tracking branch 'origin/main' into breaking
2 parents ef5a0d1 + 7c12485 commit d29f810

File tree

15 files changed

+475
-475
lines changed

15 files changed

+475
-475
lines changed

.github/workflows/Tests.yml

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,16 +26,12 @@ jobs:
2626
# not included in the others.
2727
- name: "mcmc/gibbs"
2828
args: "mcmc/gibbs.jl"
29-
- name: "mcmc/hmc"
30-
args: "mcmc/hmc.jl"
31-
- name: "mcmc/abstractmcmc"
32-
args: "mcmc/abstractmcmc.jl"
3329
- name: "mcmc/Inference"
3430
args: "mcmc/Inference.jl"
35-
- name: "mcmc/ess"
36-
args: "mcmc/ess.jl"
31+
- name: "ad"
32+
args: "ad.jl"
3733
- name: "everything else"
38-
args: "--skip mcmc/gibbs.jl mcmc/hmc.jl mcmc/abstractmcmc.jl mcmc/Inference.jl mcmc/ess.jl"
34+
args: "--skip mcmc/gibbs.jl mcmc/Inference.jl ad.jl"
3935
runner:
4036
# Default
4137
- version: '1'

HISTORY.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
# Release 0.38.4
2+
3+
The minimum Julia version was increased to 1.10.2 (from 1.10.0).
4+
On versions before 1.10.2, `sample()` took an excessively long time to run (probably due to compilation).
5+
16
# Release 0.38.3
27

38
`getparams(::Model, ::AbstractVarInfo)` now returns an empty `Float64[]` if the VarInfo contains no parameters.

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ Statistics = "1.6"
8383
StatsAPI = "1.6"
8484
StatsBase = "0.32, 0.33, 0.34"
8585
StatsFuns = "0.8, 0.9, 1"
86-
julia = "1.10"
86+
julia = "1.10.2"
8787

8888
[extras]
8989
DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb"

src/mcmc/hmc.jl

Lines changed: 36 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,36 @@ function AbstractMCMC.sample(
138138
end
139139
end
140140

141+
function find_initial_params(
142+
rng::Random.AbstractRNG,
143+
model::DynamicPPL.Model,
144+
varinfo::DynamicPPL.AbstractVarInfo,
145+
hamiltonian::AHMC.Hamiltonian;
146+
max_attempts::Int=1000,
147+
)
148+
varinfo = deepcopy(varinfo) # Don't mutate
149+
150+
for attempts in 1:max_attempts
151+
theta = varinfo[:]
152+
z = AHMC.phasepoint(rng, theta, hamiltonian)
153+
isfinite(z) && return varinfo, z
154+
155+
attempts == 10 &&
156+
@warn "failed to find valid initial parameters in $(attempts) tries; consider providing explicit initial parameters using the `initial_params` keyword"
157+
158+
# Resample and try again.
159+
# NOTE: varinfo has to be linked to make sure this samples in unconstrained space
160+
varinfo = last(
161+
DynamicPPL.evaluate!!(model, rng, varinfo, DynamicPPL.SampleFromUniform())
162+
)
163+
end
164+
165+
# if we failed to find valid initial parameters, error
166+
return error(
167+
"failed to find valid initial parameters in $(max_attempts) tries. This may indicate an error with the model or AD backend; please open an issue at https://github.com/TuringLang/Turing.jl/issues",
168+
)
169+
end
170+
141171
function DynamicPPL.initialstep(
142172
rng::AbstractRNG,
143173
model::AbstractModel,
@@ -170,33 +200,14 @@ function DynamicPPL.initialstep(
170200
lp_grad_func = Base.Fix1(LogDensityProblems.logdensity_and_gradient, ldf)
171201
hamiltonian = AHMC.Hamiltonian(metric, lp_func, lp_grad_func)
172202

173-
# Compute phase point z.
174-
z = AHMC.phasepoint(rng, theta, hamiltonian)
175-
176203
# If no initial parameters are provided, resample until the log probability
177-
# and its gradient are finite.
178-
if initial_params === nothing
179-
init_attempt_count = 1
180-
while !isfinite(z)
181-
if init_attempt_count == 10
182-
@warn "failed to find valid initial parameters in $(init_attempt_count) tries; consider providing explicit initial parameters using the `initial_params` keyword"
183-
end
184-
if init_attempt_count == 1000
185-
error(
186-
"failed to find valid initial parameters in $(init_attempt_count) tries. This may indicate an error with the model or AD backend; please open an issue at https://github.com/TuringLang/Turing.jl/issues",
187-
)
188-
end
189-
190-
# NOTE: This will sample in the unconstrained space.
191-
vi = last(DynamicPPL.evaluate!!(model, rng, vi, SampleFromUniform()))
192-
theta = vi[:]
193-
194-
hamiltonian = AHMC.Hamiltonian(metric, lp_func, lp_grad_func)
195-
z = AHMC.phasepoint(rng, theta, hamiltonian)
196-
197-
init_attempt_count += 1
198-
end
204+
# and its gradient are finite. Otherwise, just use the existing parameters.
205+
vi, z = if initial_params === nothing
206+
find_initial_params(rng, model, vi, hamiltonian)
207+
else
208+
vi, AHMC.phasepoint(rng, theta, hamiltonian)
199209
end
210+
theta = vi[:]
200211

201212
# Cache current log density.
202213
log_density_old = getlogp(vi)

test/Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
[deps]
2+
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
23
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
34
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
45
AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
@@ -20,7 +21,6 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
2021
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
2122
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
2223
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
23-
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
2424
NamedArrays = "86f7a689-2022-50b4-a561-43c23ac3c673"
2525
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
2626
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
@@ -39,6 +39,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3939
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
4040

4141
[compat]
42+
ADTypes = "1"
4243
AbstractMCMC = "5"
4344
AbstractPPL = "0.9, 0.10, 0.11"
4445
AdvancedMH = "0.6, 0.7, 0.8"
@@ -52,15 +53,14 @@ Combinatorics = "1"
5253
Distributions = "0.25"
5354
DistributionsAD = "0.6.3"
5455
DynamicHMC = "2.1.6, 3.0"
55-
DynamicPPL = "0.36"
56+
DynamicPPL = "0.36.6"
5657
FiniteDifferences = "0.10.8, 0.11, 0.12"
5758
ForwardDiff = "0.10.12 - 0.10.32, 0.10"
5859
HypothesisTests = "0.11"
5960
LinearAlgebra = "1"
6061
LogDensityProblems = "2"
6162
LogDensityProblemsAD = "1.4"
6263
MCMCChains = "5, 6, 7"
63-
Mooncake = "0.4.95"
6464
NamedArrays = "0.9.4, 0.10"
6565
Optim = "1"
6666
Optimization = "3, 4"

0 commit comments

Comments
 (0)