Skip to content

Commit 4b9665a

Browse files
mhaurusunxd3
andauthored
Remove samplers from VarInfo - Selectors and GIDs (#808)
* Remove Selectors and Gibbs IDs * Remove getspace * Remove a dead VNV method --------- Co-authored-by: Xianda Sun <[email protected]>
1 parent 29f3760 commit 4b9665a

15 files changed

+40
-226
lines changed

HISTORY.md

+2
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ This release removes the feature of `VarInfo` where it kept track of which varia
1515
- `eltype(::VarInfo)` no longer accepts a sampler as an argument
1616
- `keys(::VarInfo)` no longer accepts a sampler as an argument
1717
- `VarInfo(::VarInfo, ::Sampler, ::AbstactVector)` no longer accepts the sampler argument.
18+
- `push!!` and `push!` no longer accept samplers or `Selector`s as arguments
19+
- `getgid`, `setgid!`, `updategid!`, `getspace`, and `inspace` no longer exist
1820

1921
### Reverse prefixing order
2022

docs/src/api.md

-7
Original file line numberDiff line numberDiff line change
@@ -289,13 +289,6 @@ unset_flag!
289289
is_flagged
290290
```
291291

292-
For Gibbs sampling the following functions were added.
293-
294-
```@docs
295-
setgid!
296-
updategid!
297-
```
298-
299292
The following functions were used for sequential Monte Carlo methods.
300293

301294
```@docs

ext/DynamicPPLChainRulesCoreExt.jl

+1-9
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,7 @@ end
1010

1111
# See https://github.com/TuringLang/Turing.jl/issues/1199
1212
ChainRulesCore.@non_differentiable BangBang.push!!(
13-
vi::DynamicPPL.VarInfo,
14-
vn::DynamicPPL.VarName,
15-
r,
16-
dist::Distributions.Distribution,
17-
gidset::Set{DynamicPPL.Selector},
18-
)
19-
20-
ChainRulesCore.@non_differentiable DynamicPPL.updategid!(
21-
vi::DynamicPPL.AbstractVarInfo, vn::DynamicPPL.VarName, spl::DynamicPPL.Sampler
13+
vi::DynamicPPL.VarInfo, vn::DynamicPPL.VarName, r, dist::Distributions.Distribution
2214
)
2315

2416
# No need + causes issues for some AD backends, e.g. Zygote.

src/DynamicPPL.jl

-12
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,6 @@ export AbstractVarInfo,
6363
is_flagged,
6464
set_flag!,
6565
unset_flag!,
66-
setgid!,
67-
updategid!,
6866
setorder!,
6967
istrans,
7068
link,
@@ -74,7 +72,6 @@ export AbstractVarInfo,
7472
values_as,
7573
# VarName (reexport from AbstractPPL)
7674
VarName,
77-
inspace,
7875
subsumes,
7976
@varname,
8077
# Compiler
@@ -152,9 +149,6 @@ macro prob_str(str)
152149
))
153150
end
154151

155-
# Used here and overloaded in Turing
156-
function getspace end
157-
158152
"""
159153
AbstractVarInfo
160154
@@ -166,14 +160,8 @@ See also: [`VarInfo`](@ref), [`SimpleVarInfo`](@ref).
166160
"""
167161
abstract type AbstractVarInfo <: AbstractModelTrace end
168162

169-
const LEGACY_WARNING = """
170-
!!! warning
171-
This method is considered legacy, and is likely to be deprecated in the future.
172-
"""
173-
174163
# Necessary forward declarations
175164
include("utils.jl")
176-
include("selector.jl")
177165
include("chains.jl")
178166
include("model.jl")
179167
include("sampler.jl")

src/abstract_varinfo.jl

-46
Original file line numberDiff line numberDiff line change
@@ -169,51 +169,6 @@ See also: [`getindex(vi::AbstractVarInfo, vn::VarName, dist::Distribution)`](@re
169169
"""
170170
function getindex_internal end
171171

172-
"""
173-
push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution)
174-
175-
Push a new random variable `vn` with a sampled value `r` from a distribution `dist` to
176-
the `VarInfo` `vi`, mutating if it makes sense.
177-
"""
178-
function BangBang.push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution)
179-
return BangBang.push!!(vi, vn, r, dist, Set{Selector}([]))
180-
end
181-
182-
"""
183-
push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, spl::AbstractSampler)
184-
185-
Push a new random variable `vn` with a sampled value `r` sampled with a sampler `spl`
186-
from a distribution `dist` to `VarInfo` `vi`, if it makes sense.
187-
188-
The sampler is passed here to invalidate its cache where defined.
189-
190-
$(LEGACY_WARNING)
191-
"""
192-
function BangBang.push!!(
193-
vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, spl::Sampler
194-
)
195-
return BangBang.push!!(vi, vn, r, dist, spl.selector)
196-
end
197-
function BangBang.push!!(
198-
vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, spl::AbstractSampler
199-
)
200-
return BangBang.push!!(vi, vn, r, dist)
201-
end
202-
203-
"""
204-
push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Selector)
205-
206-
Push a new random variable `vn` with a sampled value `r` sampled with a sampler of
207-
selector `gid` from a distribution `dist` to `VarInfo` `vi`.
208-
209-
$(LEGACY_WARNING)
210-
"""
211-
function BangBang.push!!(
212-
vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Selector
213-
)
214-
return BangBang.push!!(vi, vn, r, dist, Set([gid]))
215-
end
216-
217172
@doc """
218173
empty!!(vi::AbstractVarInfo)
219174
@@ -768,7 +723,6 @@ end
768723
# Legacy code that is currently overloaded for the sake of simplicity.
769724
# TODO: Remove when possible.
770725
increment_num_produce!(::AbstractVarInfo) = nothing
771-
setgid!(vi::AbstractVarInfo, gid::Selector, vn::VarName) = nothing
772726

773727
"""
774728
from_internal_transform(varinfo::AbstractVarInfo, vn::VarName[, dist])

src/context_implementations.jl

+6-6
Original file line numberDiff line numberDiff line change
@@ -239,11 +239,11 @@ function assume(
239239
r = init(rng, dist, sampler)
240240
if istrans(vi)
241241
f = to_linked_internal_transform(vi, vn, dist)
242-
push!!(vi, vn, f(r), dist, sampler)
242+
push!!(vi, vn, f(r), dist)
243243
# By default `push!!` sets the transformed flag to `false`.
244244
settrans!!(vi, true, vn)
245245
else
246-
push!!(vi, vn, r, dist, sampler)
246+
push!!(vi, vn, r, dist)
247247
end
248248
end
249249

@@ -466,11 +466,11 @@ function get_and_set_val!(
466466
vn = vns[i]
467467
if istrans(vi)
468468
ri_linked = _link_broadcast_new(vi, vn, dist, r[:, i])
469-
push!!(vi, vn, ri_linked, dist, spl)
469+
push!!(vi, vn, ri_linked, dist)
470470
# `push!!` sets the trans-flag to `false` by default.
471471
settrans!!(vi, true, vn)
472472
else
473-
push!!(vi, vn, r[:, i], dist, spl)
473+
push!!(vi, vn, r[:, i], dist)
474474
end
475475
end
476476
end
@@ -513,14 +513,14 @@ function get_and_set_val!(
513513
# 2. Define an anonymous function which returns `nothing`, which
514514
# we then broadcast. This will allocate a vector of `nothing` though.
515515
if istrans(vi)
516-
push!!.((vi,), vns, _link_broadcast_new.((vi,), vns, dists, r), dists, (spl,))
516+
push!!.((vi,), vns, _link_broadcast_new.((vi,), vns, dists, r), dists)
517517
# NOTE: Need to add the correction.
518518
# FIXME: This is not great.
519519
acclogp!!(vi, sum(logabsdetjac.(link_transform.(dists), r)))
520520
# `push!!` sets the trans-flag to `false` by default.
521521
settrans!!.((vi,), true, vns)
522522
else
523-
push!!.((vi,), vns, r, dists, (spl,))
523+
push!!.((vi,), vns, r, dists)
524524
end
525525
end
526526
return r

src/sampler.jl

+2-7
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@ Sampling algorithm that samples unobserved random variables from their prior dis
1818
"""
1919
struct SampleFromPrior <: AbstractSampler end
2020

21-
getspace(::Union{SampleFromPrior,SampleFromUniform}) = ()
22-
2321
# Initializations.
2422
init(rng, dist, ::SampleFromPrior) = rand(rng, dist)
2523
function init(rng, dist, ::SampleFromUniform)
@@ -31,6 +29,8 @@ function init(rng, dist, ::SampleFromUniform, n::Int)
3129
return istransformable(dist) ? inittrans(rng, dist, n) : rand(rng, dist, n)
3230
end
3331

32+
# TODO(mhauru) Could we get rid of Sampler now that it's just a wrapper around `alg`?
33+
# (Selector has been removed).
3434
"""
3535
Sampler{T}
3636
@@ -47,12 +47,7 @@ By default, values are sampled from the prior.
4747
"""
4848
struct Sampler{T} <: AbstractSampler
4949
alg::T
50-
selector::Selector # Can we remove it?
51-
# TODO: add space such that we can integrate existing external samplers in DynamicPPL
5250
end
53-
Sampler(alg) = Sampler(alg, Selector())
54-
Sampler(alg, model::Model) = Sampler(alg, model, Selector())
55-
Sampler(alg, model::Model, s::Selector) = Sampler(alg, s)
5651

5752
# AbstractMCMC interface for SampleFromUniform and SampleFromPrior
5853
function AbstractMCMC.step(

src/selector.jl

-13
This file was deleted.

src/simple_varinfo.jl

+6-22
Original file line numberDiff line numberDiff line change
@@ -374,42 +374,26 @@ end
374374

375375
# `NamedTuple`
376376
function BangBang.push!!(
377-
vi::SimpleVarInfo{<:NamedTuple},
378-
vn::VarName{sym,typeof(identity)},
379-
value,
380-
dist::Distribution,
381-
gidset::Set{Selector},
377+
vi::SimpleVarInfo{<:NamedTuple}, ::VarName{sym,typeof(identity)}, value, ::Distribution
382378
) where {sym}
383379
return Accessors.@set vi.values = merge(vi.values, NamedTuple{(sym,)}((value,)))
384380
end
385381
function BangBang.push!!(
386-
vi::SimpleVarInfo{<:NamedTuple},
387-
vn::VarName{sym},
388-
value,
389-
dist::Distribution,
390-
gidset::Set{Selector},
382+
vi::SimpleVarInfo{<:NamedTuple}, vn::VarName{sym}, value, ::Distribution
391383
) where {sym}
392384
return Accessors.@set vi.values = set!!(vi.values, vn, value)
393385
end
394386

395387
# `AbstractDict`
396388
function BangBang.push!!(
397-
vi::SimpleVarInfo{<:AbstractDict},
398-
vn::VarName,
399-
value,
400-
dist::Distribution,
401-
gidset::Set{Selector},
389+
vi::SimpleVarInfo{<:AbstractDict}, vn::VarName, value, ::Distribution
402390
)
403391
vi.values[vn] = value
404392
return vi
405393
end
406394

407395
function BangBang.push!!(
408-
vi::SimpleVarInfo{<:VarNamedVector},
409-
vn::VarName,
410-
value,
411-
dist::Distribution,
412-
gidset::Set{Selector},
396+
vi::SimpleVarInfo{<:VarNamedVector}, vn::VarName, value, ::Distribution
413397
)
414398
# The semantics of push!! for SimpleVarInfo and VarNamedVector are different. For
415399
# SimpleVarInfo, push!! allows the key to exist already, for VarNamedVector it does not.
@@ -483,7 +467,7 @@ function assume(
483467
value = init(rng, dist, sampler)
484468
# Transform if we're working in unconstrained space.
485469
value_raw = to_maybe_linked_internal(vi, vn, dist, value)
486-
vi = BangBang.push!!(vi, vn, value_raw, dist, sampler)
470+
vi = BangBang.push!!(vi, vn, value_raw, dist)
487471
return value, Bijectors.logpdf_with_trans(dist, value, istrans(vi, vn)), vi
488472
end
489473

@@ -550,7 +534,7 @@ function settrans!!(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, trans)
550534
end
551535

552536
istrans(vi::SimpleVarInfo) = !(vi.transformation isa NoTransformation)
553-
istrans(vi::SimpleVarInfo, vn::VarName) = istrans(vi)
537+
istrans(vi::SimpleVarInfo, ::VarName) = istrans(vi)
554538
istrans(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, vn::VarName) = istrans(vi.varinfo, vn)
555539

556540
islinked(vi::SimpleVarInfo) = istrans(vi)

src/threadsafe.jl

+2-7
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,8 @@ end
5757

5858
has_varnamedvector(vi::DynamicPPL.ThreadSafeVarInfo) = has_varnamedvector(vi.varinfo)
5959

60-
function BangBang.push!!(
61-
vi::ThreadSafeVarInfo, vn::VarName, r, dist::Distribution, gidset::Set{Selector}
62-
)
63-
return Accessors.@set vi.varinfo = push!!(vi.varinfo, vn, r, dist, gidset)
60+
function BangBang.push!!(vi::ThreadSafeVarInfo, vn::VarName, r, dist::Distribution)
61+
return Accessors.@set vi.varinfo = push!!(vi.varinfo, vn, r, dist)
6462
end
6563

6664
get_num_produce(vi::ThreadSafeVarInfo) = get_num_produce(vi.varinfo)
@@ -70,9 +68,6 @@ set_num_produce!(vi::ThreadSafeVarInfo, n::Int) = set_num_produce!(vi.varinfo, n
7068

7169
syms(vi::ThreadSafeVarInfo) = syms(vi.varinfo)
7270

73-
function setgid!(vi::ThreadSafeVarInfo, gid::Selector, vn::VarName)
74-
return setgid!(vi.varinfo, gid, vn)
75-
end
7671
setorder!(vi::ThreadSafeVarInfo, vn::VarName, index::Int) = setorder!(vi.varinfo, vn, index)
7772
setval!(vi::ThreadSafeVarInfo, val, vn::VarName) = setval!(vi.varinfo, val, vn)
7873

0 commit comments

Comments
 (0)