Commit f86999b

Unify HMC + externalsampler + DynamicHMC parameter initialisation; re-export DynamicPPL.set_logprob_type! (#2794)
Closes #2739.

As a nice by-product of using `rand(ldf)` rather than `vi[:]`, we also avoid accidentally promoting Float32 to Float64. This means that (together with TuringLang/DynamicPPL.jl#1328 and tpapp/DynamicHMC.jl#199) one can do

```julia
julia> using DynamicPPL; DynamicPPL.set_logprob_type!(Float32)
┌ Info: DynamicPPL's log probability type has been set to Float32.
└ Please note you will need to restart your Julia session for this change to take effect.
```

and then after restarting

```julia
julia> using Turing, FlexiChains, DynamicHMC

julia> @model function f()
           x ~ Normal(0.0f0, 1.0f0)
       end
f (generic function with 2 methods)

julia> chn = sample(f(), externalsampler(DynamicHMC.NUTS()), 100; chain_type=VNChain)
Sampling 100%|████████████████████████████████████████████| Time: 0:00:02
FlexiChain (100 iterations, 1 chain)
↓ iter=1:100 | → chain=1:1

Parameter type VarName
Parameters x
Extra keys :logprior, :loglikelihood, :logjoint

julia> eltype(chn[@varname(x)])
Float32

julia> eltype(chn[:logjoint])
Float32
```

(Previously, the values of `x` would be Float32, but logjoint would be Float64. And if you used MCMCChains, everything would be Float64.)
1 parent 838ae6a commit f86999b

18 files changed

Lines changed: 391 additions & 163 deletions

.github/workflows/FloatTypes.yml

Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
+name: Float type promotion
+
+on:
+  push:
+    branches:
+      - main
+  pull_request:
+
+# needed to allow julia-actions/cache to delete old caches that it has created
+permissions:
+  actions: write
+  contents: read
+
+# Cancel existing tests on the same PR if a new commit is added to a pull request
+concurrency:
+  group: ${{ github.workflow }}-${{ github.ref || github.run_id }}
+  cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }}
+
+jobs:
+  floattypes:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v6
+
+      - uses: julia-actions/setup-julia@v2
+        with:
+          version: "1"
+
+      - uses: julia-actions/cache@v2
+
+      - name: Run float type tests
+        working-directory: test/floattypes
+        run: |
+          julia --project=. --color=yes -e 'using Pkg; Pkg.instantiate()'
+          julia --project=. --color=yes main.jl setup f64
+          julia --project=. --color=yes main.jl run f64
+          julia --project=. --color=yes main.jl setup f32
+          julia --project=. --color=yes main.jl run f32
+          julia --project=. --color=yes main.jl setup f16
+          julia --project=. --color=yes main.jl run f16
+          julia --project=. --color=yes main.jl setup min
+          julia --project=. --color=yes main.jl run min

.gitignore

Lines changed: 1 addition & 2 deletions
@@ -25,7 +25,6 @@ docs/.jekyll-cache
 .vscode
 .DS_Store
 Manifest.toml
-/Manifest.toml
-/test/Manifest.toml
+LocalPreferences.toml
 
 benchmarks/output/

HISTORY.md

Lines changed: 12 additions & 0 deletions
@@ -1,3 +1,15 @@
+# 0.43.3
+
+Unify parameter initialisation for HMC and external samplers.
+External samplers (like HMC) now attempt multiple times to generate valid initial parameters, instead of just taking the first set of parameters.
+
+Re-exports `set_logprob_type!` from DynamicPPL to allow users to control the base log-probability type used when evaluating Turing models.
+For example, calling `set_logprob_type!(Float32)` will mean that Turing will use `Float32` for log-probability calculations, only promoting if something in the model causes it to (e.g. a distribution that returns `Float64` log-probabilities).
+Note that this is a compile-time preference: for it to take effect you will have to restart your Julia session after calling `set_logprob_type!`.
+
+Furthermore, note that sampler support for non-`Float64` log-probabilities is currently limited.
+Although DynamicPPL promises not to promote float types unnecessarily, many samplers, including HMC and NUTS, still use `Float64` internally and thus will cause log-probabilities and parameters to be promoted to `Float64`, even if the model itself uses `Float32`.
+
 # 0.43.2
 
 Throw an `ArgumentError` when a `Gibbs` sampler is missing component samplers for any variable in the model.
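
The promotion that this changelog entry warns about follows from Julia's ordinary arithmetic promotion rules. The sketch below uses only Base Julia and makes no Turing or DynamicPPL calls; the names `logp32`, `stepsize64`, `stepsize32`, and `params32` are illustrative and do not come from the commit.

```julia
# Mixing Float32 with Float64 promotes the result to Float64.
logp32 = -1.5f0        # a Float32 value standing in for a log-probability
stepsize64 = 0.1       # a Float64 constant, as a sampler might hold internally

promoted = logp32 * stepsize64
@assert typeof(promoted) === Float64

# Keeping every operand in Float32 preserves the type.
stepsize32 = 0.1f0
kept = logp32 * stepsize32
@assert typeof(kept) === Float32

# Vectors behave the same way under broadcasting.
params32 = Float32[0.1, 0.2, 0.3]
@assert eltype(params32 .* stepsize64) === Float64
@assert eltype(params32 .* stepsize32) === Float32
```

A single `Float64` literal anywhere in a sampler's update is therefore enough to turn `Float32` parameters and log-probabilities into `Float64`.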

Project.toml

Lines changed: 2 additions & 2 deletions
@@ -1,6 +1,6 @@
 name = "Turing"
 uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
-version = "0.43.2"
+version = "0.43.3"
 
 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -61,7 +61,7 @@ DifferentiationInterface = "0.7"
 Distributions = "0.25.77"
 DocStringExtensions = "0.8, 0.9"
 DynamicHMC = "3.4"
-DynamicPPL = "0.40.6"
+DynamicPPL = "0.40.15"
 EllipticalSliceSampling = "0.5, 1, 2"
 ForwardDiff = "0.10.3, 1"
 Libtask = "0.9.14"

docs/src/api.md

Lines changed: 6 additions & 5 deletions
@@ -46,6 +46,7 @@ even though [`Prior()`](@ref) is actually defined in the `Turing.Inference` modu
 | `setthreadsafe` | [`DynamicPPL.setthreadsafe`](@extref) | Mark a model as requiring threadsafe evaluation |
 | `might_produce` | [`Libtask.might_produce`](@extref) | Mark a method signature as potentially calling `Libtask.produce` |
 | `@might_produce` | [`Libtask.@might_produce`](@extref) | Mark a function name as potentially calling `Libtask.produce` |
+| `set_logprob_type!` | [`DynamicPPL.set_logprob_type!`](@extref) | Set the base log-probability type used during evaluation of Turing models |
 
 ### Inference
 
@@ -81,11 +82,11 @@ even though [`Prior()`](@ref) is actually defined in the `Turing.Inference` modu
 
 ### Data structures
 
-| Exported symbol | Documentation | Description |
-|:--------------- |:------------------------------------------- |:----------------------------------- |
-| `@vnt` | [`DynamicPPL.@vnt`](@extref) | Generate a `VarNameTuple` |
-| `VarNamedTuple` | [`DynamicPPL.VarNamedTuple`](@extref) | A mapping from `VarName`s to values |
-| `OrderedDict` | [`OrderedCollections.OrderedDict`](@extref) | An ordered dictionary |
+| Exported symbol | Documentation | Description |
+|:--------------- |:---------------------------------------------------- |:----------------------------------- |
+| `@vnt` | [`DynamicPPL.VarNamedTuples.@vnt`](@extref) | Generate a `VarNameTuple` |
+| `VarNamedTuple` | [`DynamicPPL.VarNamedTuples.VarNamedTuple`](@extref) | A mapping from `VarName`s to values |
+| `OrderedDict` | [`OrderedCollections.OrderedDict`](@extref) | An ordered dictionary |
 
 ### DynamicPPL utilities

ext/TuringDynamicHMCExt.jl

Lines changed: 8 additions & 11 deletions
@@ -50,26 +50,23 @@ function AbstractMCMC.step(
     initial_params,
     kwargs...,
 )
-    # Define log-density function.
-    # TODO(penelopeysm) We need to check that the initial parameters are valid. Same as how
-    # we do it for HMC
-    _, vi = DynamicPPL.init!!(
-        rng, model, DynamicPPL.VarInfo(), initial_params, DynamicPPL.LinkAll()
-    )
-    ℓ = DynamicPPL.LogDensityFunction(
-        model, DynamicPPL.getlogjoint_internal, vi; adtype=spl.adtype
+    # Construct LogDensityFunction
+    tfm_strategy = DynamicPPL.LinkAll()
+    ldf = DynamicPPL.LogDensityFunction(
+        model, DynamicPPL.getlogjoint_internal, tfm_strategy; adtype=spl.adtype
     )
+    x = Turing.Inference.find_initial_params_ldf(rng, ldf, initial_params)
 
     # Perform initial step.
     results = DynamicHMC.mcmc_keep_warmup(
-        rng, ℓ, 0; initialization=(q=vi[:],), reporter=DynamicHMC.NoProgressReport()
+        rng, ldf, 0; initialization=(q=x,), reporter=DynamicHMC.NoProgressReport()
     )
     steps = DynamicHMC.mcmc_steps(results.sampling_logdensity, results.final_warmup_state)
     Q, _ = DynamicHMC.mcmc_next_step(steps, results.final_warmup_state.Q)
 
     # Create first sample and state.
-    sample = DynamicPPL.ParamsWithStats(Q.q, ℓ)
-    state = DynamicNUTSState(ℓ, Q, steps.H.κ, steps.ϵ)
+    sample = DynamicPPL.ParamsWithStats(Q.q, ldf)
+    state = DynamicNUTSState(ldf, Q, steps.H.κ, steps.ϵ)
 
     return sample, state
 end

src/Turing.jl

Lines changed: 5 additions & 1 deletion
@@ -78,7 +78,9 @@ using DynamicPPL:
     InitFromParams,
     setthreadsafe,
     filldist,
-    arraydist
+    arraydist,
+    set_logprob_type!
+
 using StatsBase: predict
 using OrderedCollections: OrderedDict
 using Libtask: might_produce, @might_produce
@@ -163,6 +165,8 @@ export
     fix,
     unfix,
     OrderedDict, # OrderedCollections
+    # Log-prob types in accumulators
+    set_logprob_type!,
     # Initialisation strategies for models
     InitFromPrior,
     InitFromUniform,

src/mcmc/abstractmcmc.jl

Lines changed: 38 additions & 0 deletions
@@ -31,6 +31,44 @@ function _convert_initial_params(@nospecialize(_::Any))
     throw(ArgumentError(errmsg))
 end
 
+"""
+    find_initial_params_ldf(rng, ldf, init_strategy; max_attempts=1000)
+
+Given a `LogDensityFunction` and an initialization strategy, attempt to find valid initial
+parameters by sampling from the initialization strategy and checking that the log density
+(and gradient, if available) are finite. If valid parameters are not found after
+`max_attempts`, throw an error.
+"""
+function find_initial_params_ldf(
+    rng::Random.AbstractRNG,
+    ldf::DynamicPPL.LogDensityFunction,
+    init_strategy::DynamicPPL.AbstractInitStrategy;
+    max_attempts::Int=1000,
+)
+    for attempts in 1:max_attempts
+        # Get new parameters
+        x = rand(rng, ldf, init_strategy)
+        is_valid = if ldf.adtype === nothing
+            logp = LogDensityProblems.logdensity(ldf, x)
+            isfinite(logp)
+        else
+            logp, grad = LogDensityProblems.logdensity_and_gradient(ldf, x)
+            isfinite(logp) && all(isfinite, grad)
+        end
+
+        # If they're OK, return them
+        is_valid && return x
+
+        attempts == 10 &&
+            @warn "failed to find valid initial parameters in $(attempts) tries; consider providing a different initialisation strategy with the `initial_params` keyword"
+    end
+
+    # if we failed to find valid initial parameters, error
+    return error(
+        "failed to find valid initial parameters in $(max_attempts) tries. See https://turinglang.org/docs/uri/initial-parameters for common causes and solutions. If the issue persists, please open an issue at https://github.com/TuringLang/Turing.jl/issues",
+    )
+end
+
 #########################################
 # Default definitions for the interface #
 #########################################
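
The retry pattern in `find_initial_params_ldf` can be tried in isolation with a dependency-free sketch. This is not the Turing implementation: `toy_logdensity` and `find_finite_start` are hypothetical names, candidates come from `randn` rather than from an initialisation strategy, and only the log density (not its gradient) is checked.

```julia
using Random

# Hypothetical stand-in for a log density: finite only when x[1] > 0.
toy_logdensity(x) = x[1] > 0 ? -0.5 * sum(abs2, x) + log(x[1]) : -Inf

# Draw candidates until the log density is finite, or give up after max_attempts.
function find_finite_start(
    rng::Random.AbstractRNG, logdensity, dim::Int; max_attempts::Int=1000
)
    for attempt in 1:max_attempts
        x = randn(rng, dim)                    # propose a candidate
        isfinite(logdensity(x)) && return x    # accept the first valid one
    end
    return error("failed to find valid initial parameters in $(max_attempts) tries")
end

rng = Random.MersenneTwister(1)
x0 = find_finite_start(rng, toy_logdensity, 2)
@assert isfinite(toy_logdensity(x0))
```

Because the check happens before any sampler state is built, the same helper can serve every sampler that works on a flat parameter vector, which is the unification this commit is after.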

src/mcmc/external_sampler.jl

Lines changed: 12 additions & 17 deletions
@@ -149,28 +149,17 @@ function AbstractMCMC.step(
 ) where {unconstrained}
     sampler = sampler_wrapper.sampler
 
-    # Initialise varinfo with initial params and link the varinfo if needed.
-    tfm_strategy = unconstrained ? DynamicPPL.LinkAll() : DynamicPPL.UnlinkAll()
-    _, varinfo = DynamicPPL.init!!(rng, model, VarInfo(), initial_params, tfm_strategy)
-
-    # We need to extract the vectorised initial_params, because the later call to
-    # AbstractMCMC.step only sees a `LogDensityModel` which expects `initial_params`
-    # to be a vector.
-    initial_params_vector = varinfo[:]
-
     # Construct LogDensityFunction
+    tfm_strategy = unconstrained ? DynamicPPL.LinkAll() : DynamicPPL.UnlinkAll()
     f = DynamicPPL.LogDensityFunction(
-        model, DynamicPPL.getlogjoint_internal, varinfo; adtype=sampler_wrapper.adtype
+        model, DynamicPPL.getlogjoint_internal, tfm_strategy; adtype=sampler_wrapper.adtype
     )
+    x = find_initial_params_ldf(rng, f, initial_params)
 
     # Then just call `AbstractMCMC.step` with the right arguments.
     _, state_inner = if initial_state === nothing
         AbstractMCMC.step(
-            rng,
-            AbstractMCMC.LogDensityModel(f),
-            sampler;
-            initial_params=initial_params_vector,
-            kwargs...,
+            rng, AbstractMCMC.LogDensityModel(f), sampler; initial_params=x, kwargs...
         )
 
     else
@@ -179,7 +168,7 @@ function AbstractMCMC.step(
             AbstractMCMC.LogDensityModel(f),
             sampler,
             initial_state;
-            initial_params=initial_params_vector,
+            initial_params=x,
             kwargs...,
         )
     end
@@ -191,7 +180,13 @@ function AbstractMCMC.step(
         new_stats = AbstractMCMC.getstats(state_inner)
         DynamicPPL.ParamsWithStats(new_parameters, f, new_stats)
     end
-    return (new_transition, TuringState(state_inner, varinfo, new_parameters, f))
+
+    # TODO(penelopeysm): this varinfo is only needed for Gibbs. The external sampler itself
+    # has no use for it. Get rid of this as soon as possible.
+    vi = DynamicPPL.link!!(VarInfo(model), model)
+    vi = DynamicPPL.unflatten!!(vi, x)
+
+    return (new_transition, TuringState(state_inner, vi, new_parameters, f))
 end
 
 function AbstractMCMC.step(

src/mcmc/hmc.jl

Lines changed: 15 additions & 53 deletions
@@ -150,41 +150,10 @@ function AbstractMCMC.sample(
     end
 end
 
-function find_initial_params(
-    rng::Random.AbstractRNG,
-    model::DynamicPPL.Model,
-    varinfo::DynamicPPL.AbstractVarInfo,
-    hamiltonian::AHMC.Hamiltonian,
-    init_strategy::DynamicPPL.AbstractInitStrategy;
-    max_attempts::Int=1000,
-)
-    varinfo = deepcopy(varinfo) # Don't mutate
-
-    for attempts in 1:max_attempts
-        theta = varinfo[:]
-        z = AHMC.phasepoint(rng, theta, hamiltonian)
-        isfinite(z) && return varinfo, z
-
-        attempts == 10 &&
-            @warn "failed to find valid initial parameters in $(attempts) tries; consider providing a different initialisation strategy with the `initial_params` keyword"
-
-        # Resample and try again.
-        _, varinfo = DynamicPPL.init!!(
-            rng, model, varinfo, init_strategy, DynamicPPL.LinkAll()
-        )
-    end
-
-    # if we failed to find valid initial parameters, error
-    return error(
-        "failed to find valid initial parameters in $(max_attempts) tries. See https://turinglang.org/docs/uri/initial-parameters for common causes and solutions. If the issue persists, please open an issue at https://github.com/TuringLang/Turing.jl/issues",
-    )
-end
-
-function Turing.Inference.initialstep(
+function AbstractMCMC.step(
     rng::AbstractRNG,
     model::DynamicPPL.Model,
-    spl::Hamiltonian,
-    vi_original::AbstractVarInfo;
+    spl::Hamiltonian;
     # the initial_params kwarg is always passed on from sample(), cf. DynamicPPL
     # src/sampler.jl, so we don't need to provide a default value here
     initial_params::DynamicPPL.AbstractInitStrategy,
@@ -193,32 +162,19 @@ function Turing.Inference.initialstep(
     verbose::Bool=true,
     kwargs...,
 )
-    # Transform the samples to unconstrained space and compute the joint log probability.
-    vi = DynamicPPL.link(vi_original, model)
-
-    # Extract parameters.
-    theta = vi[:]
-
-    # Create a Hamiltonian.
-    metricT = getmetricT(spl)
-    metric = metricT(length(theta))
+    # Create a Hamiltonian
     ldf = DynamicPPL.LogDensityFunction(
-        model, DynamicPPL.getlogjoint_internal, vi; adtype=spl.adtype
+        model, DynamicPPL.getlogjoint_internal, DynamicPPL.LinkAll(); adtype=spl.adtype
     )
+    metricT = getmetricT(spl)
+    metric = metricT(LogDensityProblems.dimension(ldf))
     lp_func = Base.Fix1(LogDensityProblems.logdensity, ldf)
     lp_grad_func = Base.Fix1(LogDensityProblems.logdensity_and_gradient, ldf)
     hamiltonian = AHMC.Hamiltonian(metric, lp_func, lp_grad_func)
 
-    # Note that there is already one round of 'initialisation' before we reach this step,
-    # inside DynamicPPL's `AbstractMCMC.step` implementation. That leads to a possible issue
-    # that this `find_initial_params` function might override the parameters set by the
-    # user.
-    # Luckily for us, `find_initial_params` always checks if the logp and its gradient are
-    # finite. If it is already finite with the params inside the current `vi`, it doesn't
-    # attempt to find new ones. This means that the parameters passed to `sample()` will be
-    # respected instead of being overridden here.
-    vi, z = find_initial_params(rng, model, vi, hamiltonian, initial_params)
-    theta = vi[:]
+    # Find initial values
+    theta = find_initial_params_ldf(rng, ldf, initial_params)
+    z = AHMC.phasepoint(rng, theta, hamiltonian)
 
     # Find good eps if not provided one
     if iszero(spl.ϵ)
@@ -236,6 +192,12 @@ function Turing.Inference.initialstep(
     else
         DynamicPPL.ParamsWithStats(theta, ldf, NamedTuple())
     end
+
+    # TODO(penelopeysm): this varinfo is only needed for Gibbs. HMC itself has no use for
+    # it. Get rid of this as soon as possible.
+    vi = DynamicPPL.link!!(VarInfo(model), model)
+    vi = DynamicPPL.unflatten!!(vi, theta)
+
     state = HMCState(vi, 0, kernel, hamiltonian, z, adaptor, ldf)
 
     return transition, state
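
The `lp_func` and `lp_grad_func` lines in the hunk above rely on `Base.Fix1`, which fixes the first argument of a two-argument function and returns a one-argument callable. A minimal sketch using only Base Julia follows; `ToyProblem` and `gauss_logdensity` are hypothetical stand-ins for the `LogDensityFunction` and `LogDensityProblems.logdensity` used in the diff.

```julia
# Hypothetical problem type and a two-argument log density in the
# shape `logdensity(problem, x)`.
struct ToyProblem
    scale::Float64
end
gauss_logdensity(p::ToyProblem, x) = -0.5 * sum(abs2, x) / p.scale^2

problem = ToyProblem(2.0)

# Fix the first argument, leaving a function of the parameter vector only.
lp = Base.Fix1(gauss_logdensity, problem)

x = [1.0, -1.0]
@assert lp(x) == gauss_logdensity(problem, x)
@assert lp(x) == -0.25
```

A Hamiltonian constructor that expects a function of the parameter vector alone can then take `lp` directly, with the problem object captured inside it.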
