Skip to content

Commit 2e6ad77

Browse files
authored
Try #372:
2 parents d222316 + 4492504 commit 2e6ad77

File tree

3 files changed

+36
-5
lines changed

3 files changed

+36
-5
lines changed

src/compiler.jl

+13-5
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,9 @@ function unwrap_right_left_vns(
175175
return unwrap_right_left_vns(right, left, vns)
176176
end
177177

178+
resolve_varnames(vn::VarName, _) = vn
179+
resolve_varnames(vn::VarName, dist::NamedDist) = dist.name
180+
178181
#################
179182
# Main Compiler #
180183
#################
@@ -379,16 +382,19 @@ function generate_tilde(left, right)
379382

380383
# Otherwise it is determined by the model or its value,
381384
# if the LHS represents an observation
382-
@gensym vn isassumption value
385+
@gensym vn isassumption value dist
383386

384387
# HACK: Usage of `drop_escape` is unfortunate. It's a consequence of the fact
385388
# that in DynamicPPL we the entire function body. Instead we should be
386389
# more selective with our escape. Until that's the case, we remove them all.
387390
return quote
388-
$vn = $(AbstractPPL.drop_escape(varname(left)))
391+
$dist = $right
392+
$vn = $(DynamicPPL.resolve_varnames)(
393+
$(AbstractPPL.drop_escape(varname(left))), $dist
394+
)
389395
$isassumption = $(DynamicPPL.isassumption(left, vn))
390396
if $isassumption
391-
$(generate_tilde_assume(left, right, vn))
397+
$(generate_tilde_assume(left, dist, vn))
392398
else
393399
# If `vn` is not in `argnames`, we need to make sure that the variable is defined.
394400
if !$(DynamicPPL.inargnames)($vn, __model__)
@@ -397,7 +403,7 @@ function generate_tilde(left, right)
397403

398404
$value, __varinfo__ = $(DynamicPPL.tilde_observe!!)(
399405
__context__,
400-
$(DynamicPPL.check_tilde_rhs)($right),
406+
$(DynamicPPL.check_tilde_rhs)($dist),
401407
$(maybe_view(left)),
402408
$vn,
403409
__varinfo__,
@@ -442,7 +448,9 @@ function generate_dot_tilde(left, right)
442448
# if the LHS represents an observation
443449
@gensym vn isassumption value
444450
return quote
445-
$vn = $(AbstractPPL.drop_escape(varname(left)))
451+
$vn = $(DynamicPPL.resolve_varnames)(
452+
$(AbstractPPL.drop_escape(varname(left))), $right
453+
)
446454
$isassumption = $(DynamicPPL.isassumption(left, vn))
447455
if $isassumption
448456
$(generate_dot_tilde_assume(left, right, vn))

src/distribution_wrappers.jl

+11
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,17 @@ end
1313

1414
NamedDist(dist::Distribution, name::Symbol) = NamedDist(dist, VarName{name}())
1515

16+
Distributions.logpdf(dist::NamedDist, x::Real) = Distributions.logpdf(dist.dist, x)
17+
function Distributions.logpdf(dist::NamedDist, x::AbstractArray{<:Real})
18+
return Distributions.logpdf(dist.dist, x)
19+
end
20+
function Distributions.loglikelihood(dist::NamedDist, x::Real)
21+
return Distributions.loglikelihood(dist.dist, x)
22+
end
23+
function Distributions.loglikelihood(dist::NamedDist, x::AbstractArray{<:Real})
24+
return Distributions.loglikelihood(dist.dist, x)
25+
end
26+
1627
struct NoDist{variate,support,Td<:Distribution{variate,support}} <:
1728
Distribution{variate,support}
1829
dist::Td

test/compiler.jl

+12
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,18 @@ end
312312
@test vi2.metadata.y.vns[1] == @varname(y[2][:, 1])
313313
@test haskey(vi3.metadata, :y)
314314
@test vi3.metadata.y.vns[1] == @varname(y[1])
315+
316+
# Conditioning
317+
f1_c = f1() | (y=1,)
318+
f2_c = f2() | NamedTuple((Symbol(@varname(y[2][:, 1])) => 1,))
319+
f3_c = f3() | NamedTuple((Symbol(@varname(y[1])) => 1,))
320+
@test f1_c() == 1
321+
# TODO(torfjelde): We need conditioning for `Dict`.
322+
@test_broken f2_c() == 1
323+
@test_broken f3_c() == 1
324+
@test_broken getlogp(VarInfo(f1_c)) ==
325+
getlogp(VarInfo(f2_c)) ==
326+
getlogp(VarInfo(f3_c))
315327
end
316328
@testset "custom tilde" begin
317329
@model demo() = begin

0 commit comments

Comments
 (0)