Skip to content

Commit 748b191

Browse files
committed
Support functors (#398)
Fixes #367. Additionally, I removed the `name` field of `Model` since it seemed redundant with `nameof(model.f)` (if `model.f isa Function`) and `Symbol(model.f)` (otherwise). This could be separated or reverted. TODO: - [x] Add tests Co-authored-by: David Widmann <[email protected]>
1 parent de40505 commit 748b191

10 files changed

+63
-31
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.18.0"
3+
version = "0.19.0"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

benchmarks/benchmark_body.jmd

+3-3
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ results["evaluation_typed"]
2020
```
2121

2222
```julia; echo=false; results="hidden";
23-
BenchmarkTools.save(joinpath("results", WEAVE_ARGS[:name], "$(m.name)_benchmarks.json"), results)
23+
BenchmarkTools.save(joinpath("results", WEAVE_ARGS[:name], "$(nameof(m))_benchmarks.json"), results)
2424
```
2525

2626
```julia; wrap=false
@@ -32,15 +32,15 @@ end
3232
```julia; echo=false; results="hidden"
3333
if WEAVE_ARGS[:include_typed_code]
3434
# Serialize the output of `typed_code` so we can compare later.
35-
haskey(WEAVE_ARGS, :name) && serialize(joinpath("results", WEAVE_ARGS[:name],"$(m.name).jls"), string(typed));
35+
haskey(WEAVE_ARGS, :name) && serialize(joinpath("results", WEAVE_ARGS[:name],"$(nameof(m)).jls"), string(typed));
3636
end
3737
```
3838

3939
```julia; wrap=false; echo=false;
4040
if haskey(WEAVE_ARGS, :name_old)
4141
# We want to compare the generated code to the previous version.
4242
import DiffUtils
43-
typed_old = deserialize(joinpath("results", WEAVE_ARGS[:name_old], "$(m.name).jls"));
43+
typed_old = deserialize(joinpath("results", WEAVE_ARGS[:name_old], "$(nameof(m)).jls"));
4444
DiffUtils.diff(typed_old, string(typed), width=130)
4545
end
4646
```

src/compiler.jl

+12-7
Original file line numberDiff line numberDiff line change
@@ -594,19 +594,24 @@ function build_output(modelinfo, linenumbernode)
594594
allargs_namedtuple = modelinfo[:allargs_namedtuple]
595595
defaults_namedtuple = modelinfo[:defaults_namedtuple]
596596

597+
# Obtain or generate the name of the model to support functors:
598+
# https://github.com/TuringLang/DynamicPPL.jl/issues/367
599+
modeldef = modelinfo[:modeldef]
600+
if MacroTools.@capture(modeldef[:name], ::T_)
601+
name = gensym(:f)
602+
modeldef[:name] = Expr(:(::), name, T)
603+
elseif MacroTools.@capture(modeldef[:name], (name_::_ | name_))
604+
else
605+
throw(ArgumentError("unsupported format of model function"))
606+
end
607+
597608
# Update the function body of the user-specified model.
598609
# We use `MacroTools.@q begin ... end` instead of regular `quote ... end` to ensure
599610
# that no new `LineNumberNode`s are added apart from the reference `linenumbernode`
600611
# to the call site
601-
modeldef = modelinfo[:modeldef]
602612
modeldef[:body] = MacroTools.@q begin
603613
$(linenumbernode)
604-
return $(DynamicPPL.Model)(
605-
$(QuoteNode(modeldef[:name])),
606-
$(modeldef[:name]),
607-
$allargs_namedtuple,
608-
$defaults_namedtuple,
609-
)
614+
return $(DynamicPPL.Model)($name, $allargs_namedtuple, $defaults_namedtuple)
610615
end
611616

612617
return MacroTools.@q begin

src/model.jl

+10-14
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
"""
22
struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults}
3-
name::Symbol
43
f::F
54
args::NamedTuple{argnames,Targs}
65
defaults::NamedTuple{defaultnames,Tdefaults}
@@ -34,53 +33,49 @@ Model{typeof(f),(:x, :y),(:x,),(:y,),Tuple{Float64,Float64},Tuple{Int64}}(f, (x
3433
"""
3534
struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext} <:
3635
AbstractProbabilisticProgram
37-
name::Symbol
3836
f::F
3937
args::NamedTuple{argnames,Targs}
4038
defaults::NamedTuple{defaultnames,Tdefaults}
4139
context::Ctx
4240

4341
@doc """
44-
Model{missings}(name::Symbol, f, args::NamedTuple, defaults::NamedTuple)
42+
Model{missings}(f, args::NamedTuple, defaults::NamedTuple)
4543
46-
Create a model of name `name` with evaluation function `f` and missing arguments
47-
overwritten by `missings`.
44+
Create a model with evaluation function `f` and missing arguments overwritten by
45+
`missings`.
4846
"""
4947
function Model{missings}(
50-
name::Symbol,
5148
f::F,
5249
args::NamedTuple{argnames,Targs},
5350
defaults::NamedTuple{defaultnames,Tdefaults},
5451
context::Ctx=DefaultContext(),
5552
) where {missings,F,argnames,Targs,defaultnames,Tdefaults,Ctx}
5653
return new{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx}(
57-
name, f, args, defaults, context
54+
f, args, defaults, context
5855
)
5956
end
6057
end
6158

6259
"""
63-
Model(name::Symbol, f, args::NamedTuple[, defaults::NamedTuple = ()])
60+
Model(f, args::NamedTuple[, defaults::NamedTuple = ()])
6461
65-
Create a model of name `name` with evaluation function `f` and missing arguments deduced
66-
from `args`.
62+
Create a model with evaluation function `f` and missing arguments deduced from `args`.
6763
6864
Default arguments `defaults` are used internally when constructing instances of the same
6965
model with different arguments.
7066
"""
7167
@generated function Model(
72-
name::Symbol,
7368
f::F,
7469
args::NamedTuple{argnames,Targs},
7570
defaults::NamedTuple=NamedTuple(),
7671
context::AbstractContext=DefaultContext(),
7772
) where {F,argnames,Targs}
7873
missings = Tuple(name for (name, typ) in zip(argnames, Targs.types) if typ <: Missing)
79-
return :(Model{$missings}(name, f, args, defaults, context))
74+
return :(Model{$missings}(f, args, defaults, context))
8075
end
8176

8277
function contextualize(model::Model, context::AbstractContext)
83-
return Model(model.name, model.f, model.args, model.defaults, context)
78+
return Model(model.f, model.args, model.defaults, context)
8479
end
8580

8681
"""
@@ -518,7 +513,8 @@ getmissings(model::Model{_F,_a,_d,missings}) where {missings,_F,_a,_d} = missing
518513
519514
Get the name of the `model` as `Symbol`.
520515
"""
521-
Base.nameof(model::Model) = model.name
516+
Base.nameof(model::Model) = Symbol(model.f)
517+
Base.nameof(model::Model{<:Function}) = nameof(model.f)
522518

523519
"""
524520
rand([rng=Random.GLOBAL_RNG], [T=NamedTuple], model::Model)

src/prob_macro.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ end
179179
return quote
180180
$(warnings...)
181181
Model{$(Tuple(missings))}(
182-
model.name, model.f, $(to_namedtuple_expr(argnames, argvals)), model.defaults
182+
model.f, $(to_namedtuple_expr(argnames, argvals)), model.defaults
183183
)
184184
end
185185
end
@@ -237,6 +237,6 @@ end
237237
# `args` is inserted as properly typed NamedTuple expression;
238238
# `missings` is splatted into a tuple at compile time and inserted as literal
239239
return :(Model{$(Tuple(missings))}(
240-
model.name, model.f, $(to_namedtuple_expr(argnames, argvals)), model.defaults
240+
model.f, $(to_namedtuple_expr(argnames, argvals)), model.defaults
241241
))
242242
end

src/submodel_macro.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ julia> @varname(var"my prefix.x") in keys(VarInfo(outer()))
148148
true
149149
150150
julia> # Using string interpolation.
151-
@model outer() = @submodel prefix="\$(inner().name)" a = inner()
151+
@model outer() = @submodel prefix="\$(nameof(inner()))" a = inner()
152152
outer (generic function with 2 methods)
153153
154154
julia> @varname(var"inner.x") in keys(VarInfo(outer()))

src/test_utils.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ function test_sampler_demo_models(
290290
rtol=1e-3,
291291
kwargs...,
292292
)
293-
@testset "$(nameof(typeof(sampler))) on $(m.name)" for model in DEMO_MODELS
293+
@testset "$(nameof(typeof(sampler))) on $(nameof(m))" for model in DEMO_MODELS
294294
chain = AbstractMCMC.sample(model, sampler, args...; kwargs...)
295295
μ = meanfunction(chain)
296296
@test μ target atol = atol rtol = rtol

test/model.jl

+30
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,17 @@
1+
# some functors (#367)
2+
struct MyModel
3+
a::Int
4+
end
5+
@model function (f::MyModel)(x)
6+
m ~ Normal(f.a, 1)
7+
return x ~ Normal(m, 1)
8+
end
9+
struct MyZeroModel end
10+
@model function (::MyZeroModel)(x)
11+
m ~ Normal(0, 1)
12+
return x ~ Normal(m, 1)
13+
end
14+
115
@testset "model.jl" begin
216
@testset "convenience functions" begin
317
model = gdemo_default
@@ -61,9 +75,25 @@
6175
m ~ Normal(0, 1)
6276
x ~ Normal(m, 1)
6377
end
78+
function test3 end
79+
@model function (::typeof(test3))(x)
80+
m ~ Normal(0, 1)
81+
return x ~ Normal(m, 1)
82+
end
83+
function test4 end
84+
@model function (a::typeof(test4))(x)
85+
m ~ Normal(0, 1)
86+
return x ~ Normal(m, 1)
87+
end
6488

6589
@test nameof(test1(rand())) == :test1
6690
@test nameof(test2(rand())) == :test2
91+
@test nameof(test3(rand())) == :test3
92+
@test nameof(test4(rand())) == :test4
93+
94+
# callables
95+
@test nameof(MyModel(3)(rand())) == Symbol("MyModel(3)")
96+
@test nameof(MyZeroModel()(rand())) == Symbol("MyZeroModel()")
6797
end
6898

6999
@testset "Internal methods" begin

test/simple_varinfo.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@
5858
end
5959
end
6060

61-
@testset "SimpleVarInfo on $(model.name)" for model in DynamicPPL.TestUtils.DEMO_MODELS
61+
@testset "SimpleVarInfo on $(nameof(model))" for model in
62+
DynamicPPL.TestUtils.DEMO_MODELS
6263
# We might need to pre-allocate for the variable `m`, so we need
6364
# to see whether this is the case.
6465
m = model().m

test/turing/Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
55
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
66

77
[compat]
8-
DynamicPPL = "0.18"
8+
DynamicPPL = "0.19"
99
Turing = "0.18, 0.19, 0.20"
1010
julia = "1.3"

0 commit comments

Comments
 (0)