Skip to content

Commit 0810e14

Browse files
Clean up varinfo get/set functions (#853)
* Clean up varinfo get/set functions * Re-add but deprecate VarInfo(::VarInfo, ::AbstractVector) * Increase atol for truncated bijector test * Update test/simple_varinfo.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent e4fa7f2 commit 0810e14

8 files changed

+64
-100
lines changed

HISTORY.md

+11
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,16 @@
11
# DynamicPPL Changelog
22

3+
## 0.35.5
4+
5+
Several internal methods have been removed:
6+
7+
- `DynamicPPL.getall(vi::AbstractVarInfo)` has been removed. You can directly replace this with `getindex_internal(vi, Colon())`.
8+
- `DynamicPPL.setall!(vi::AbstractVarInfo, values)` has been removed. Rewrite the calling function to not assume mutation and use `unflatten(vi, values)` instead.
9+
- `DynamicPPL.replace_values(md::Metadata, values)` and `DynamicPPL.replace_values(nt::NamedTuple, values)` (where the `nt` is a NamedTuple of Metadatas) have been removed. Use `DynamicPPL.unflatten_metadata` as a direct replacement.
10+
- `DynamicPPL.set_values!!(vi::AbstractVarInfo, values)` has been renamed to `DynamicPPL.set_initial_values(vi::AbstractVarInfo, values)`; it also no longer mutates the varinfo argument.
11+
12+
The **exported** method `VarInfo(vi::VarInfo, values)` has been deprecated, and will be removed in the next minor version. You can replace this directly with `unflatten(vi, values)` instead.
13+
314
## 0.35.4
415

516
Fixed a type instability in an implementation of `with_logabsdet_jacobian`, which resulted in the log-jacobian returned being an Int in some cases and a Float in others.

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.35.4"
3+
version = "0.35.5"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/abstract_varinfo.jl

+9-1
Original file line numberDiff line numberDiff line change
@@ -162,8 +162,16 @@ Base.getindex(vi::AbstractVarInfo, ::Colon) = values_as(vi, Vector)
162162
"""
163163
getindex_internal(vi::AbstractVarInfo, vn::VarName)
164164
getindex_internal(vi::AbstractVarInfo, vns::Vector{<:VarName})
165+
getindex_internal(vi::AbstractVarInfo, ::Colon)
165166
166-
Return the current value(s) of `vn` (`vns`) in `vi` as represented internally in `vi`.
167+
Return the internal value of the varname `vn`, varnames `vns`, or all varnames
168+
in `vi` respectively. The internal value is the value of the variables that is
169+
stored in the varinfo object; this may be the actual realisation of the random
170+
variable (i.e. the value sampled from the distribution), or it may have been
171+
transformed to Euclidean space, depending on whether the varinfo was linked.
172+
173+
See https://turinglang.org/docs/developers/transforms/dynamicppl/ for more
174+
information on how transformed variables are stored in DynamicPPL.
167175
168176
See also: [`getindex(vi::AbstractVarInfo, vn::VarName, dist::Distribution)`](@ref)
169177
"""

src/sampler.jl

+21-6
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,21 @@ By default, it returns an instance of [`SampleFromPrior`](@ref).
151151
"""
152152
initialsampler(spl::Sampler) = SampleFromPrior()
153153

154-
function set_values!!(varinfo::AbstractVarInfo, initial_params::AbstractVector)
154+
"""
155+
set_initial_values(varinfo::AbstractVarInfo, initial_params::AbstractVector)
156+
set_initial_values(varinfo::AbstractVarInfo, initial_params::NamedTuple)
157+
158+
Take the values inside `initial_params`, replace the corresponding values in
159+
the given VarInfo object, and return a new VarInfo object with the updated values.
160+
161+
This differs from `DynamicPPL.unflatten` in two ways:
162+
163+
1. It works with `NamedTuple` arguments.
164+
2. For the `AbstractVector` method, if any of the elements are missing, it will not
165+
overwrite the original value in the VarInfo (it will just use the original
166+
value instead).
167+
"""
168+
function set_initial_values(varinfo::AbstractVarInfo, initial_params::AbstractVector)
155169
throw(
156170
ArgumentError(
157171
"`initial_params` must be a vector of type `Union{Real,Missing}`. " *
@@ -160,7 +174,7 @@ function set_values!!(varinfo::AbstractVarInfo, initial_params::AbstractVector)
160174
)
161175
end
162176

163-
function set_values!!(
177+
function set_initial_values(
164178
varinfo::AbstractVarInfo, initial_params::AbstractVector{<:Union{Real,Missing}}
165179
)
166180
flattened_param_vals = varinfo[:]
@@ -180,11 +194,12 @@ function set_values!!(
180194
end
181195

182196
# Update in `varinfo`.
183-
setall!(varinfo, flattened_param_vals)
184-
return varinfo
197+
new_varinfo = unflatten(varinfo, flattened_param_vals)
198+
return new_varinfo
185199
end
186200

187-
function set_values!!(varinfo::AbstractVarInfo, initial_params::NamedTuple)
201+
function set_initial_values(varinfo::AbstractVarInfo, initial_params::NamedTuple)
202+
varinfo = deepcopy(varinfo)
188203
vars_in_varinfo = keys(varinfo)
189204
for v in keys(initial_params)
190205
vn = VarName{v}()
@@ -219,7 +234,7 @@ function initialize_parameters!!(vi::AbstractVarInfo, initial_params, model::Mod
219234
end
220235

221236
# Set the values in `vi`.
222-
vi = set_values!!(vi, initial_params)
237+
vi = set_initial_values(vi, initial_params)
223238

224239
# `invlink` if needed.
225240
if linked

src/varinfo.jl

+15-87
Original file line numberDiff line numberDiff line change
@@ -100,20 +100,14 @@ const TypedVarInfo = VarInfo{<:NamedTuple}
100100
const VarInfoOrThreadSafeVarInfo{Tmeta} = Union{
101101
VarInfo{Tmeta},ThreadSafeVarInfo{<:VarInfo{Tmeta}}
102102
}
103+
# TODO: Remove this
104+
@deprecate VarInfo(vi::VarInfo, x::AbstractVector) unflatten(vi, x)
103105

104106
# NOTE: This is kind of weird, but it effectively preserves the "old"
105107
# behavior where we're allowed to call `link!` on the same `VarInfo`
106108
# multiple times.
107109
transformation(::VarInfo) = DynamicTransformation()
108110

109-
# TODO(mhauru) Isn't this the same as unflatten and/or replace_values?
110-
function VarInfo(old_vi::VarInfo, x::AbstractVector)
111-
md = replace_values(old_vi.metadata, x)
112-
return VarInfo(
113-
md, Base.RefValue{eltype(x)}(getlogp(old_vi)), Ref(get_num_produce(old_vi))
114-
)
115-
end
116-
117111
# No-op if we're already working with a `VarNamedVector`.
118112
metadata_to_varnamedvector(vnv::VarNamedVector) = vnv
119113
function metadata_to_varnamedvector(md::Metadata)
@@ -243,9 +237,8 @@ end
243237
return :($(exprs...),)
244238
end
245239

246-
# For Metadata unflatten and replace_values are the same. For VarNamedVector they are not.
247240
function unflatten_metadata(md::Metadata, x::AbstractVector)
248-
return replace_values(md, x)
241+
return Metadata(md.idcs, md.vns, md.ranges, x, md.dists, md.orders, md.flags)
249242
end
250243

251244
unflatten_metadata(vnv::VarNamedVector, x::AbstractVector) = unflatten(vnv, x)
@@ -255,31 +248,6 @@ function VarInfo(rng::Random.AbstractRNG, model::Model, context::AbstractContext
255248
return VarInfo(rng, model, SampleFromPrior(), context)
256249
end
257250

258-
function replace_values(metadata::Metadata, x)
259-
return Metadata(
260-
metadata.idcs,
261-
metadata.vns,
262-
metadata.ranges,
263-
x,
264-
metadata.dists,
265-
metadata.orders,
266-
metadata.flags,
267-
)
268-
end
269-
270-
@generated function replace_values(metadata::NamedTuple{names}, x) where {names}
271-
exprs = []
272-
offset = :(0)
273-
for f in names
274-
mdf = :(metadata.$f)
275-
len = :(sum(length, $mdf.ranges))
276-
push!(exprs, :($f = replace_values($mdf, x[($offset + 1):($offset + $len)])))
277-
offset = :($offset + $len)
278-
end
279-
length(exprs) == 0 && return :(NamedTuple())
280-
return :($(exprs...),)
281-
end
282-
283251
####
284252
#### Internal functions
285253
####
@@ -652,10 +620,20 @@ getindex_internal(vi::VarInfo, vn::VarName) = getindex_internal(getmetadata(vi,
652620
# what a bijector would result in, even if the input is a view (`SubArray`).
653621
# TODO(torfjelde): An alternative is to implement `view` directly instead.
654622
getindex_internal(md::Metadata, vn::VarName) = getindex(md.vals, getrange(md, vn))
655-
656623
function getindex_internal(vi::VarInfo, vns::Vector{<:VarName})
657624
return mapreduce(Base.Fix1(getindex_internal, vi), vcat, vns)
658625
end
626+
getindex_internal(vi::VarInfo, ::Colon) = getindex_internal(vi.metadata, Colon())
627+
# NOTE: `mapreduce` over `NamedTuple` results in worse type-inference.
628+
# See for example https://github.com/JuliaLang/julia/pull/46381.
629+
function getindex_internal(vi::TypedVarInfo, ::Colon)
630+
return reduce(vcat, map(Base.Fix2(getindex_internal, Colon()), vi.metadata))
631+
end
632+
function getindex_internal(md::Metadata, ::Colon)
633+
return mapreduce(
634+
Base.Fix1(getindex_internal, md), vcat, md.vns; init=similar(md.vals, 0)
635+
)
636+
end
659637

660638
"""
661639
setval!(vi::VarInfo, val, vn::VarName)
@@ -672,56 +650,6 @@ function setval!(md::Metadata, val, vn::VarName)
672650
return md.vals[getrange(md, vn)] = tovec(val)
673651
end
674652

675-
"""
676-
getall(vi::VarInfo)
677-
678-
Return the values of all the variables in `vi`.
679-
680-
The values may or may not be transformed to Euclidean space.
681-
"""
682-
getall(vi::VarInfo) = getall(vi.metadata)
683-
# NOTE: `mapreduce` over `NamedTuple` results in worse type-inference.
684-
# See for example https://github.com/JuliaLang/julia/pull/46381.
685-
getall(vi::TypedVarInfo) = reduce(vcat, map(getall, vi.metadata))
686-
function getall(md::Metadata)
687-
return mapreduce(
688-
Base.Fix1(getindex_internal, md), vcat, md.vns; init=similar(md.vals, 0)
689-
)
690-
end
691-
getall(vnv::VarNamedVector) = getindex_internal(vnv, Colon())
692-
693-
"""
694-
setall!(vi::VarInfo, val)
695-
696-
Set the values of all the variables in `vi` to `val`.
697-
698-
The values may or may not be transformed to Euclidean space.
699-
"""
700-
setall!(vi::VarInfo, val) = _setall!(vi.metadata, val)
701-
702-
function _setall!(metadata::Metadata, val)
703-
for r in metadata.ranges
704-
metadata.vals[r] .= val[r]
705-
end
706-
end
707-
function _setall!(vnv::VarNamedVector, val)
708-
# TODO(mhauru) Do something more efficient here.
709-
for i in 1:length_internal(vnv)
710-
setindex_internal!(vnv, val[i], i)
711-
end
712-
end
713-
@generated function _setall!(metadata::NamedTuple{names}, val) where {names}
714-
expr = Expr(:block)
715-
start = :(1)
716-
for f in names
717-
length = :(sum(length, metadata.$f.ranges))
718-
finish = :($start + $length - 1)
719-
push!(expr.args, :(copyto!(metadata.$f.vals, 1, val, $start, $length)))
720-
start = :($start + $length)
721-
end
722-
return expr
723-
end
724-
725653
function settrans!!(vi::VarInfo, trans::Bool, vn::VarName)
726654
settrans!!(getmetadata(vi, vn), trans, vn)
727655
return vi
@@ -2114,7 +2042,7 @@ function _setval_and_resample_kernel!(
21142042
end
21152043

21162044
values_as(vi::VarInfo) = vi.metadata
2117-
values_as(vi::VarInfo, ::Type{Vector}) = copy(getall(vi))
2045+
values_as(vi::VarInfo, ::Type{Vector}) = copy(getindex_internal(vi, Colon()))
21182046
function values_as(vi::UntypedVarInfo, ::Type{NamedTuple})
21192047
iter = values_from_metadata(vi.metadata)
21202048
return NamedTuple(map(p -> Symbol(p.first) => p.second, iter))

test/model.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -163,12 +163,12 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
163163
Random.seed!(100 + i)
164164
vi = VarInfo()
165165
model(Random.default_rng(), vi, sampler)
166-
vals = DynamicPPL.getall(vi)
166+
vals = vi[:]
167167

168168
Random.seed!(100 + i)
169169
vi = VarInfo()
170170
model(Random.default_rng(), vi, sampler)
171-
@test DynamicPPL.getall(vi) == vals
171+
@test vi[:] == vals
172172
end
173173
end
174174
end
@@ -240,7 +240,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
240240
for i in 1:10
241241
# Sample with large variations.
242242
r_raw = randn(length(vi[:])) * 10
243-
DynamicPPL.setall!(vi, r_raw)
243+
vi = DynamicPPL.unflatten(vi, r_raw)
244244
@test vi[@varname(m)] == r_raw[1]
245245
@test vi[@varname(x)] != r_raw[2]
246246
model(vi)

test/simple_varinfo.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,8 @@
257257

258258
# `getlogp` should be equal to the logjoint with log-absdet-jac correction.
259259
lp = getlogp(svi)
260-
@test lp lp_true
260+
# needs higher atol because of https://github.com/TuringLang/Bijectors.jl/issues/375
261+
@test lp lp_true atol = 1.2e-5
261262
end
262263
end
263264
end

test/test_util.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,10 @@
88
end
99
const gdemo_default = gdemo_d()
1010

11+
# TODO(penelopeysm): Remove this (and also test/compat/ad.jl)
1112
function test_model_ad(model, logp_manual)
1213
vi = VarInfo(model)
13-
x = DynamicPPL.getall(vi)
14+
x = vi[:]
1415

1516
# Log probabilities using the model.
1617
= DynamicPPL.LogDensityFunction(model, vi)

0 commit comments

Comments
 (0)