Skip to content

Commit 903e0c6

Browse files
authored
Add varname tests from DPPL + format repo (#111)
* Add varname tests from DPPL cf. TuringLang/DynamicPPL.jl#737 * Format * Format readme
1 parent 1c5408b commit 903e0c6

File tree

7 files changed

+242
-191
lines changed

7 files changed

+242
-191
lines changed

README.md

+99-111
Large diffs are not rendered by default.

docs/make.jl

+6-6
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@ using Documenter
22
using AbstractPPL
33

44
# Doctest setup
5-
DocMeta.setdocmeta!(AbstractPPL, :DocTestSetup, :(using AbstractPPL); recursive = true)
5+
DocMeta.setdocmeta!(AbstractPPL, :DocTestSetup, :(using AbstractPPL); recursive=true)
66

77
makedocs(;
8-
sitename = "AbstractPPL",
9-
modules = [AbstractPPL],
10-
pages = ["Home" => "index.md", "API" => "api.md"],
11-
checkdocs = :exports,
12-
doctest = false,
8+
sitename="AbstractPPL",
9+
modules=[AbstractPPL],
10+
pages=["Home" => "index.md", "API" => "api.md"],
11+
checkdocs=:exports,
12+
doctest=false,
1313
)

src/AbstractPPL.jl

+2-3
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,13 @@ export VarName,
1616
varname_to_string,
1717
string_to_varname
1818

19-
2019
# Abstract model functions
21-
export AbstractProbabilisticProgram, condition, decondition, fix, unfix, logdensityof, densityof, AbstractContext, evaluate!!
20+
export AbstractProbabilisticProgram,
21+
condition, decondition, fix, unfix, logdensityof, densityof, AbstractContext, evaluate!!
2222

2323
# Abstract traces
2424
export AbstractModelTrace
2525

26-
2726
include("varname.jl")
2827
include("abstractmodeltrace.jl")
2928
include("abstractprobprog.jl")

src/abstractprobprog.jl

-5
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ using AbstractMCMC
22
using DensityInterface
33
using Random
44

5-
65
"""
76
AbstractProbabilisticProgram
87
@@ -12,7 +11,6 @@ abstract type AbstractProbabilisticProgram <: AbstractMCMC.AbstractModel end
1211

1312
DensityInterface.DensityKind(::AbstractProbabilisticProgram) = HasDensity()
1413

15-
1614
"""
1715
logdensityof(model, trace)
1816
@@ -26,7 +24,6 @@ probability theory.
2624
"""
2725
DensityInterface.logdensityof(::AbstractProbabilisticProgram, ::AbstractModelTrace)
2826

29-
3027
"""
3128
decondition(conditioned_model)
3229
@@ -43,7 +40,6 @@ should hold for models `m` with conditioned variables `obs`.
4340
"""
4441
function decondition end
4542

46-
4743
"""
4844
condition(model, observations)
4945
@@ -84,7 +80,6 @@ should hold for any model `m` and parameters `params`.
8480
"""
8581
function fix end
8682

87-
8883
"""
8984
unfix(model)
9085

src/varname.jl

+77-38
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,11 @@ struct VarName{sym,T}
3636

3737
function VarName{sym}(optic=identity) where {sym}
3838
if !is_static_optic(typeof(optic))
39-
throw(ArgumentError("attempted to construct `VarName` with unsupported optic of type $(nameof(typeof(optic)))"))
39+
throw(
40+
ArgumentError(
41+
"attempted to construct `VarName` with unsupported optic of type $(nameof(typeof(optic)))",
42+
),
43+
)
4044
end
4145
return new{sym,typeof(optic)}(optic)
4246
end
@@ -168,7 +172,7 @@ end
168172

169173
function Base.show(io::IO, vn::VarName{sym,T}) where {sym,T}
170174
print(io, getsym(vn))
171-
_show_optic(io, getoptic(vn))
175+
return _show_optic(io, getoptic(vn))
172176
end
173177

174178
# modified from https://github.com/JuliaObjects/Accessors.jl/blob/01528a81fdf17c07436e1f3d99119d3f635e4c26/src/sugar.jl#L502
@@ -181,7 +185,7 @@ function _show_optic(io::IO, optic)
181185
print(io, "")
182186
end
183187
shortstr = reduce(_shortstring, inner; init="")
184-
print(io, shortstr)
188+
return print(io, shortstr)
185189
end
186190

187191
_shortstring(prev, o::IndexLens) = "$prev[$(join(map(prettify_index, o.indices), ", "))]"
@@ -207,7 +211,6 @@ Symbol("x[1][:]")
207211
"""
208212
Base.Symbol(vn::VarName) = Symbol(string(vn)) # simplified symbol
209213

210-
211214
"""
212215
inspace(vn::Union{VarName, Symbol}, space::Tuple)
213216
@@ -244,7 +247,6 @@ inspace(vn::VarName, space::Tuple) = any(_in(vn, s) for s in space)
244247
_in(vn::VarName, s::Symbol) = getsym(vn) == s
245248
_in(vn::VarName, s::VarName) = subsumes(s, vn)
246249

247-
248250
"""
249251
subsumes(u::VarName, v::VarName)
250252
@@ -297,8 +299,9 @@ subsumes(::typeof(identity), ::typeof(identity)) = true
297299
subsumes(::typeof(identity), ::ALLOWED_OPTICS) = true
298300
subsumes(::ALLOWED_OPTICS, ::typeof(identity)) = false
299301

300-
subsumes(t::ComposedOptic, u::ComposedOptic) =
301-
subsumes(t.outer, u.outer) && subsumes(t.inner, u.inner)
302+
function subsumes(t::ComposedOptic, u::ComposedOptic)
303+
return subsumes(t.outer, u.outer) && subsumes(t.inner, u.inner)
304+
end
302305

303306
# If `t` is still a composed lens, then there is no way it can subsume `u` since `u` is a
304307
# leaf of the "lens-tree".
@@ -317,11 +320,12 @@ subsumes(t::PropertyLens, u::PropertyLens) = false
317320
# FIXME: Does not support `DynamicIndexLens`.
318321
# FIXME: Does not correctly handle cases such as `subsumes(x, x[:])`
319322
# (but neither did old implementation).
320-
subsumes(
323+
function subsumes(
321324
t::Union{IndexLens,ComposedOptic{<:ALLOWED_OPTICS,<:IndexLens}},
322-
u::Union{IndexLens,ComposedOptic{<:ALLOWED_OPTICS,<:IndexLens}}
323-
) = subsumes_indices(t, u)
324-
325+
u::Union{IndexLens,ComposedOptic{<:ALLOWED_OPTICS,<:IndexLens}},
326+
)
327+
return subsumes_indices(t, u)
328+
end
325329

326330
"""
327331
subsumedby(t, u)
@@ -444,7 +448,6 @@ subsumes_index(i::Colon, j) = true
444448
subsumes_index(i::AbstractVector, j) = issubset(j, i)
445449
subsumes_index(i, j) = i == j
446450

447-
448451
"""
449452
ConcretizedSlice(::Base.Slice)
450453
@@ -455,10 +458,13 @@ struct ConcretizedSlice{T,R} <: AbstractVector{T}
455458
range::R
456459
end
457460

458-
ConcretizedSlice(s::Base.Slice{R}) where {R} = ConcretizedSlice{eltype(s.indices),R}(s.indices)
461+
function ConcretizedSlice(s::Base.Slice{R}) where {R}
462+
return ConcretizedSlice{eltype(s.indices),R}(s.indices)
463+
end
459464
Base.show(io::IO, s::ConcretizedSlice) = print(io, ":")
460-
Base.show(io::IO, ::MIME"text/plain", s::ConcretizedSlice) =
461-
print(io, "ConcretizedSlice(", s.range, ")")
465+
function Base.show(io::IO, ::MIME"text/plain", s::ConcretizedSlice)
466+
return print(io, "ConcretizedSlice(", s.range, ")")
467+
end
462468
Base.size(s::ConcretizedSlice) = size(s.range)
463469
Base.iterate(s::ConcretizedSlice, state...) = Base.iterate(s.range, state...)
464470
Base.collect(s::ConcretizedSlice) = collect(s.range)
@@ -480,8 +486,9 @@ The only purpose of this are special cases like `:`, which we want to avoid beco
480486
`ConcretizedSlice` based on the `lowered_index`, just what you'd get with an explicit `begin:end`
481487
"""
482488
reconcretize_index(original_index, lowered_index) = lowered_index
483-
reconcretize_index(original_index::Colon, lowered_index::Base.Slice) =
484-
ConcretizedSlice(lowered_index)
489+
function reconcretize_index(original_index::Colon, lowered_index::Base.Slice)
490+
return ConcretizedSlice(lowered_index)
491+
end
485492

486493
"""
487494
concretize(l, x)
@@ -495,7 +502,9 @@ the result close to the original indexing.
495502
"""
496503
concretize(I::ALLOWED_OPTICS, x) = I
497504
concretize(I::DynamicIndexLens, x) = concretize(IndexLens(I.f(x)), x)
498-
concretize(I::IndexLens, x) = IndexLens(reconcretize_index.(I.indices, to_indices(x, I.indices)))
505+
function concretize(I::IndexLens, x)
506+
return IndexLens(reconcretize_index.(I.indices, to_indices(x, I.indices)))
507+
end
499508
function concretize(I::ComposedOptic, x)
500509
x_inner = I.inner(x) # TODO: get view here
501510
return ComposedOptic(concretize(I.outer, x_inner), concretize(I.inner, x))
@@ -646,11 +655,9 @@ function varname(expr::Expr, concretize=Accessors.need_dynamic_optic(expr))
646655
end
647656

648657
if concretize
649-
return :(
650-
$(AbstractPPL.VarName){$sym}(
658+
return :($(AbstractPPL.VarName){$sym}(
651659
$(AbstractPPL.concretize)($optics, $sym_escaped)
652-
)
653-
)
660+
))
654661
elseif Accessors.need_dynamic_optic(expr)
655662
error("Variable name `$(expr)` is dynamic and requires concretization!")
656663
else
@@ -672,7 +679,7 @@ end
672679
function _parse_obj_optic(ex)
673680
obj, optics = _parse_obj_optics(ex)
674681
optic = Expr(:call, Accessors.opticcompose, optics...)
675-
obj, optic
682+
return obj, optic
676683
end
677684

678685
# Accessors doesn't have the same support for interpolation
@@ -688,7 +695,8 @@ function _parse_obj_optics(ex)
688695
indices = Accessors.replace_underscore.(indices, collection)
689696
dims = length(indices) == 1 ? nothing : 1:length(indices)
690697
lindices = esc.(Accessors.lower_index.(collection, indices, dims))
691-
optics = :($(Accessors.DynamicIndexLens)($(esc(collection)) -> ($(lindices...),)))
698+
optics =
699+
:($(Accessors.DynamicIndexLens)($(esc(collection)) -> ($(lindices...),)))
692700
else
693701
index = esc(Expr(:tuple, indices...))
694702
optics = :($(Accessors.IndexLens)($index))
@@ -702,16 +710,20 @@ function _parse_obj_optics(ex)
702710
elseif Meta.isexpr(property, :$, 1)
703711
optics = :($(Accessors.PropertyLens){$(esc(property.args[1]))}())
704712
else
705-
throw(ArgumentError(
706-
string("Error while parsing :($ex). Second argument to `getproperty` can only be",
707-
"a `Symbol` or `String` literal, received `$property` instead.")
708-
))
713+
throw(
714+
ArgumentError(
715+
string(
716+
"Error while parsing :($ex). Second argument to `getproperty` can only be",
717+
"a `Symbol` or `String` literal, received `$property` instead.",
718+
),
719+
),
720+
)
709721
end
710722
else
711723
obj = esc(ex)
712724
return obj, ()
713725
end
714-
obj, tuple(frontoptics..., optics)
726+
return obj, tuple(frontoptics..., optics)
715727
end
716728

717729
"""
@@ -778,12 +790,27 @@ Convert an index `i` to a dictionary representation.
778790
"""
779791
index_to_dict(i::Integer) = Dict("type" => _BASE_INTEGER_TYPE, "value" => i)
780792
index_to_dict(v::Vector{Int}) = Dict("type" => _BASE_VECTOR_TYPE, "values" => v)
781-
index_to_dict(r::UnitRange) = Dict("type" => _BASE_UNITRANGE_TYPE, "start" => r.start, "stop" => r.stop)
782-
index_to_dict(r::StepRange) = Dict("type" => _BASE_STEPRANGE_TYPE, "start" => r.start, "stop" => r.stop, "step" => r.step)
783-
index_to_dict(r::Base.OneTo{I}) where {I} = Dict("type" => _BASE_ONETO_TYPE, "stop" => r.stop)
793+
function index_to_dict(r::UnitRange)
794+
return Dict("type" => _BASE_UNITRANGE_TYPE, "start" => r.start, "stop" => r.stop)
795+
end
796+
function index_to_dict(r::StepRange)
797+
return Dict(
798+
"type" => _BASE_STEPRANGE_TYPE,
799+
"start" => r.start,
800+
"stop" => r.stop,
801+
"step" => r.step,
802+
)
803+
end
804+
function index_to_dict(r::Base.OneTo{I}) where {I}
805+
return Dict("type" => _BASE_ONETO_TYPE, "stop" => r.stop)
806+
end
784807
index_to_dict(::Colon) = Dict("type" => _BASE_COLON_TYPE)
785-
index_to_dict(s::ConcretizedSlice{T,R}) where {T,R} = Dict("type" => _CONCRETIZED_SLICE_TYPE, "range" => index_to_dict(s.range))
786-
index_to_dict(t::Tuple) = Dict("type" => _BASE_TUPLE_TYPE, "values" => map(index_to_dict, t))
808+
function index_to_dict(s::ConcretizedSlice{T,R}) where {T,R}
809+
return Dict("type" => _CONCRETIZED_SLICE_TYPE, "range" => index_to_dict(s.range))
810+
end
811+
function index_to_dict(t::Tuple)
812+
return Dict("type" => _BASE_TUPLE_TYPE, "values" => map(index_to_dict, t))
813+
end
787814

788815
"""
789816
dict_to_index(dict)
@@ -839,9 +866,17 @@ function dict_to_index(dict)
839866
end
840867

841868
optic_to_dict(::typeof(identity)) = Dict("type" => "identity")
842-
optic_to_dict(::PropertyLens{sym}) where {sym} = Dict("type" => "property", "field" => String(sym))
869+
function optic_to_dict(::PropertyLens{sym}) where {sym}
870+
return Dict("type" => "property", "field" => String(sym))
871+
end
843872
optic_to_dict(i::IndexLens) = Dict("type" => "index", "indices" => index_to_dict(i.indices))
844-
optic_to_dict(c::ComposedOptic) = Dict("type" => "composed", "outer" => optic_to_dict(c.outer), "inner" => optic_to_dict(c.inner))
873+
function optic_to_dict(c::ComposedOptic)
874+
return Dict(
875+
"type" => "composed",
876+
"outer" => optic_to_dict(c.outer),
877+
"inner" => optic_to_dict(c.inner),
878+
)
879+
end
845880

846881
function dict_to_optic(dict)
847882
if dict["type"] == "identity"
@@ -857,9 +892,13 @@ function dict_to_optic(dict)
857892
end
858893
end
859894

860-
varname_to_dict(vn::VarName) = Dict("sym" => getsym(vn), "optic" => optic_to_dict(getoptic(vn)))
895+
function varname_to_dict(vn::VarName)
896+
return Dict("sym" => getsym(vn), "optic" => optic_to_dict(getoptic(vn)))
897+
end
861898

862-
dict_to_varname(dict::Dict{<:AbstractString, Any}) = VarName{Symbol(dict["sym"])}(dict_to_optic(dict["optic"]))
899+
function dict_to_varname(dict::Dict{<:AbstractString,Any})
900+
return VarName{Symbol(dict["sym"])}(dict_to_optic(dict["optic"]))
901+
end
863902

864903
"""
865904
varname_to_string(vn::VarName)

test/runtests.jl

+1-4
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,7 @@ using Test
1616
include("abstractprobprog.jl")
1717
@testset "doctests" begin
1818
DocMeta.setdocmeta!(
19-
AbstractPPL,
20-
:DocTestSetup,
21-
:(using AbstractPPL);
22-
recursive=true,
19+
AbstractPPL, :DocTestSetup, :(using AbstractPPL); recursive=true
2320
)
2421
doctest(AbstractPPL; manual=false)
2522
end

0 commit comments

Comments
 (0)