Skip to content

Commit f5e84f4

Browse files
mhaurupenelopeysm
andauthored
Remove dot_tilde pipeline (#804)
* Refactor dot_tilde, work in progress * Restrict dot_tilde to univariate dists on the RHS * Remove tests with multivariates or arrays as RHS of .~ * emove dot_tilde pipeline * Fix a .~ bug * Update HISTORY.md * Fix a tiny test bug * Re-enable some SimpleVarInfo tests * Improve changelog entry * Improve error message * Fix trivial typos * Fix pointwise_logdensity test * Remove pointless check_dot_tilde_rhs method * Add tests for old .~ syntax * Bump Mooncake patch version to v0.4.90 * Bump Mooncake to 0.4.95 --------- Co-authored-by: Penelope Yong <[email protected]>
1 parent 4b9665a commit f5e84f4

21 files changed

+198
-1008
lines changed

HISTORY.md

+69-1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,74 @@
44

55
**Breaking**
66

7+
### `.~` right hand side must be a univariate distribution
8+
9+
Previously we allowed statements like
10+
11+
```julia
12+
x .~ [Normal(), Gamma()]
13+
```
14+
15+
where the right hand side of a `.~` was an array of distributions, and ones like
16+
17+
```julia
18+
x .~ MvNormal(fill(0.0, 2), I)
19+
```
20+
21+
where the right hand side was a multivariate distribution.
22+
23+
These are no longer allowed. The only things allowed on the right hand side of a `.~` statement are univariate distributions, such as
24+
25+
```julia
26+
x = Array{Float64,3}(undef, 2, 3, 4)
27+
x .~ Normal()
28+
```
29+
30+
The reasons for this are internal code simplification and the fact that broadcasting where both sides are multidimensional but of different dimensions is typically confusing to read.
31+
32+
If the right hand side and the left hand side have the same dimension, one can simply use `~`. Arrays of distributions can be replaced with `product_distribution`. So instead of
33+
34+
```julia
35+
x .~ [Normal(), Gamma()]
36+
x .~ Normal.(y)
37+
x .~ MvNormal(fill(0.0, 2), I)
38+
```
39+
40+
do
41+
42+
```julia
43+
x ~ product_distribution([Normal(), Gamma()])
44+
x ~ product_distribution(Normal.(y))
45+
x ~ MvNormal(fill(0.0, 2), I)
46+
```
47+
48+
This is often more performant as well. Note that using `~` rather than `.~` does change the internal storage format a bit: With `.~` `x[i]` are stored as separate variables, with `~` as a single multivariate variable `x`. In most cases this does not change anything for the user, but if it does cause issues, e.g. if you are dealing with `VarInfo` objects directly and need to keep the old behavior, you can always expand into a loop, such as
49+
50+
```julia
51+
dists = Normal.(y)
52+
for i in 1:length(dists)
53+
x[i] ~ dists[i]
54+
end
55+
```
56+
57+
Cases where the right hand side is of a different dimension than the left hand side, and neither is a scalar, must be replaced with a loop. For example,
58+
59+
```julia
60+
x = Array{Float64,3}(undef, 2, 3, 4)
61+
x .~ MvNormal(fill(0, 2), I)
62+
```
63+
64+
should be replaced with something like
65+
66+
```julia
67+
x = Array{Float64,3}(2, 3, 4)
68+
for i in 1:3, j in 1:4
69+
x[:, i, j] ~ MvNormal(fill(0, 2), I)
70+
end
71+
```
72+
73+
This release also completely rewrites the internal implementation of `.~`, where from now on all `.~` statements are turned into loops over `~` statements at macro time. However, the only breaking aspect of this change is the above change to what's allowed on the right hand side.
74+
775
### Remove indexing by samplers
876

977
This release removes the feature of `VarInfo` where it kept track of which variable was associated with which sampler. This means removing all user-facing methods where `VarInfo`s where being indexed with samplers. In particular,
@@ -14,7 +82,7 @@ This release removes the feature of `VarInfo` where it kept track of which varia
1482
- `unflatten` no longer accepts a sampler as an argument
1583
- `eltype(::VarInfo)` no longer accepts a sampler as an argument
1684
- `keys(::VarInfo)` no longer accepts a sampler as an argument
17-
- `VarInfo(::VarInfo, ::Sampler, ::AbstactVector)` no longer accepts the sampler argument.
85+
- `VarInfo(::VarInfo, ::Sampler, ::AbstractVector)` no longer accepts the sampler argument.
1886
- `push!!` and `push!` no longer accept samplers or `Selector`s as arguments
1987
- `getgid`, `setgid!`, `updategid!`, `getspace`, and `inspace` no longer exist
2088

Project.toml

+1-5
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
2424
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2525
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
2626
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
27-
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
2827

2928
[weakdeps]
3029
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -33,7 +32,6 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
3332
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
3433
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
3534
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
36-
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
3735

3836
[extensions]
3937
DynamicPPLChainRulesCoreExt = ["ChainRulesCore"]
@@ -42,7 +40,6 @@ DynamicPPLForwardDiffExt = ["ForwardDiff"]
4240
DynamicPPLJETExt = ["JET"]
4341
DynamicPPLMCMCChainsExt = ["MCMCChains"]
4442
DynamicPPLMooncakeExt = ["Mooncake"]
45-
DynamicPPLZygoteRulesExt = ["ZygoteRules"]
4643

4744
[compat]
4845
ADTypes = "1"
@@ -65,10 +62,9 @@ LogDensityProblems = "2"
6562
LogDensityProblemsAD = "1.7.0"
6663
MCMCChains = "6"
6764
MacroTools = "0.5.6"
68-
Mooncake = "0.4.59"
65+
Mooncake = "0.4.95"
6966
OrderedCollections = "1"
7067
Random = "1.6"
7168
Requires = "1"
7269
Test = "1.6"
73-
ZygoteRules = "0.2"
7470
julia = "1.10"

docs/src/api.md

-2
Original file line numberDiff line numberDiff line change
@@ -440,10 +440,8 @@ DynamicPPL.Experimental.is_suitable_varinfo
440440

441441
```@docs
442442
tilde_assume
443-
dot_tilde_assume
444443
```
445444

446445
```@docs
447446
tilde_observe
448-
dot_tilde_observe
449447
```

ext/DynamicPPLZygoteRulesExt.jl

-25
This file was deleted.

src/DynamicPPL.jl

-4
Original file line numberDiff line numberDiff line change
@@ -98,13 +98,9 @@ export AbstractVarInfo,
9898
PrefixContext,
9999
ConditionContext,
100100
assume,
101-
dot_assume,
102101
observe,
103-
dot_observe,
104102
tilde_assume,
105103
tilde_observe,
106-
dot_tilde_assume,
107-
dot_tilde_observe,
108104
# Pseudo distributions
109105
NamedDist,
110106
NoDist,

src/compiler.jl

+39-52
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,16 @@ Return `true` if `expr` is a literal, e.g. `1.0` or `[1.0, ]`, and `false` other
161161
"""
162162
isliteral(e) = false
163163
isliteral(::Number) = true
164-
isliteral(e::Expr) = !isempty(e.args) && all(isliteral, e.args)
164+
function isliteral(e::Expr)
165+
# In the special case that the expression is of the form `abc[blahblah]`, we consider it
166+
# to be a literal if `abc` is a literal. This is necessary for cases like
167+
# [1.0, 2.0][idx...] ~ Normal()
168+
# which are generated when turning `.~` expressions into loops over `~` expressions.
169+
if e.head == :ref
170+
return isliteral(e.args[1])
171+
end
172+
return !isempty(e.args) && all(isliteral, e.args)
173+
end
165174

166175
"""
167176
check_tilde_rhs(x)
@@ -172,7 +181,7 @@ Check if the right-hand side `x` of a `~` is a `Distribution` or an array of
172181
function check_tilde_rhs(@nospecialize(x))
173182
return throw(
174183
ArgumentError(
175-
"the right-hand side of a `~` must be a `Distribution` or an array of `Distribution`s",
184+
"the right-hand side of a `~` must be a `Distribution`, an array of `Distribution`s, or a submodel",
176185
),
177186
)
178187
end
@@ -184,6 +193,27 @@ function check_tilde_rhs(x::Sampleable{<:Any,AutoPrefix}) where {AutoPrefix}
184193
return Sampleable{typeof(model),AutoPrefix}(model)
185194
end
186195

196+
"""
197+
check_dot_tilde_rhs(x)
198+
199+
Check if the right-hand side `x` of a `.~` is a `UnivariateDistribution`, then return `x`.
200+
"""
201+
function check_dot_tilde_rhs(@nospecialize(x))
202+
return throw(
203+
ArgumentError("the right-hand side of a `.~` must be a `UnivariateDistribution`")
204+
)
205+
end
206+
function check_dot_tilde_rhs(::AbstractArray{<:Distribution})
207+
msg = """
208+
As of v0.35, DynamicPPL does not allow arrays of distributions in `.~`. \
209+
Please use `product_distribution` instead, or write a loop if necessary. \
210+
See https://github.com/TuringLang/DynamicPPL.jl/releases/tag/v0.35.0 for more \
211+
details.\
212+
"""
213+
return throw(ArgumentError(msg))
214+
end
215+
check_dot_tilde_rhs(x::UnivariateDistribution) = x
216+
187217
"""
188218
unwrap_right_vn(right, vn)
189219
@@ -356,11 +386,8 @@ function generate_mainbody!(mod, found, expr::Expr, warn)
356386
args_dottilde = getargs_dottilde(expr)
357387
if args_dottilde !== nothing
358388
L, R = args_dottilde
359-
return Base.remove_linenums!(
360-
generate_dot_tilde(
361-
generate_mainbody!(mod, found, L, warn),
362-
generate_mainbody!(mod, found, R, warn),
363-
),
389+
return generate_mainbody!(
390+
mod, found, Base.remove_linenums!(generate_dot_tilde(L, R)), warn
364391
)
365392
end
366393

@@ -487,56 +514,16 @@ end
487514
Generate the expression that replaces `left .~ right` in the model body.
488515
"""
489516
function generate_dot_tilde(left, right)
490-
isliteral(left) && return generate_tilde_literal(left, right)
491-
492-
# Otherwise it is determined by the model or its value,
493-
# if the LHS represents an observation
494-
@gensym vn isassumption value
517+
@gensym dist left_axes idx
495518
return quote
496-
$vn = $(DynamicPPL.resolve_varnames)(
497-
$(AbstractPPL.drop_escape(varname(left, need_concretize(left)))), $right
498-
)
499-
$isassumption = $(DynamicPPL.isassumption(left, vn))
500-
if $(DynamicPPL.isfixed(left, vn))
501-
$left .= $(DynamicPPL.getfixed_nested)(__context__, $vn)
502-
elseif $isassumption
503-
$(generate_dot_tilde_assume(left, right, vn))
504-
else
505-
# If `vn` is not in `argnames`, we need to make sure that the variable is defined.
506-
if !$(DynamicPPL.inargnames)($vn, __model__)
507-
$left .= $(DynamicPPL.getconditioned_nested)(__context__, $vn)
508-
end
509-
510-
$value, __varinfo__ = $(DynamicPPL.dot_tilde_observe!!)(
511-
__context__,
512-
$(DynamicPPL.check_tilde_rhs)($right),
513-
$(maybe_view(left)),
514-
$vn,
515-
__varinfo__,
516-
)
517-
$value
519+
$dist = DynamicPPL.check_dot_tilde_rhs($right)
520+
$left_axes = axes($left)
521+
for $idx in Iterators.product($left_axes...)
522+
$left[$idx...] ~ $dist
518523
end
519524
end
520525
end
521526

522-
function generate_dot_tilde_assume(left, right, vn)
523-
# We don't need to use `Setfield.@set` here since
524-
# `.=` is always going to be inplace + needs `left` to
525-
# be something that supports `.=`.
526-
@gensym value
527-
return quote
528-
$value, __varinfo__ = $(DynamicPPL.dot_tilde_assume!!)(
529-
__context__,
530-
$(DynamicPPL.unwrap_right_left_vns)(
531-
$(DynamicPPL.check_tilde_rhs)($right), $(maybe_view(left)), $vn
532-
)...,
533-
__varinfo__,
534-
)
535-
$left .= $value
536-
$value
537-
end
538-
end
539-
540527
# Note that we cannot use `MacroTools.isdef` because
541528
# of https://github.com/FluxML/MacroTools.jl/issues/154.
542529
"""

0 commit comments

Comments
 (0)