@@ -94,12 +94,43 @@ function DynamicPPL.tilde_observe(
94
94
end
95
95
96
96
"""
97
- OptimLogDensity{M<:DynamicPPL.Model,V<:DynamicPPL.VarInfo,C<:OptimizationContext,AD<:ADTypes.AbstractADType}
97
+ OptimLogDensity{
98
+ M<:DynamicPPL.Model,
99
+ V<:DynamicPPL.VarInfo,
100
+ C<:OptimizationContext,
101
+ AD<:ADTypes.AbstractADType
102
+ }
103
+
104
+ A struct that wraps a single LogDensityFunction. Can be invoked either using
105
+
106
+ ```julia
107
+ OptimLogDensity(model, varinfo, ctx; adtype=adtype)
108
+ ```
109
+
110
+ or
111
+
112
+ ```julia
113
+ OptimLogDensity(model, ctx; adtype=adtype)
114
+ ```
115
+
116
+ If not specified, `adtype` defaults to `AutoForwardDiff()`.
117
+
118
+ An OptimLogDensity does not, in itself, obey the LogDensityProblems interface.
119
+ Thus, if you want to calculate the log density of its contents at the point
120
+ `z`, you should manually call
121
+
122
+ ```julia
123
+ LogDensityProblems.logdensity(f.ldf, z)
124
+ ```
98
125
99
- A struct that stores the negative log density function of a `DynamicPPL` model.
126
+ However, it is a callable object which returns the *negative* log density of
127
+ the underlying LogDensityFunction at the point `z`. This is done to satisfy
128
+ the Optim.jl interface.
100
129
101
- TODO(penelopeysm): It _doesn't_ really store the negative, does it? It's more like we
102
- overrode logdensity to give the negative logdensity.
130
+ ```julia
131
+ optim_ld = OptimLogDensity(model, varinfo, ctx)
132
+ optim_ld(z) # returns -logp
133
+ ```
103
134
"""
104
135
struct OptimLogDensity{
105
136
M<: DynamicPPL.Model ,
@@ -114,7 +145,7 @@ function OptimLogDensity(
114
145
model:: DynamicPPL.Model ,
115
146
vi:: DynamicPPL.VarInfo ,
116
147
ctx:: OptimizationContext ;
117
- adtype:: Union{Nothing, ADTypes.AbstractADType} = AutoForwardDiff (),
148
+ adtype:: ADTypes.AbstractADType = AutoForwardDiff (),
118
149
)
119
150
return OptimLogDensity (Turing. LogDensityFunction (model, vi, ctx; adtype= adtype))
120
151
end
123
154
function OptimLogDensity (
124
155
model:: DynamicPPL.Model ,
125
156
ctx:: OptimizationContext ;
126
- adtype:: Union{Nothing, ADTypes.AbstractADType} = AutoForwardDiff (),
157
+ adtype:: ADTypes.AbstractADType = AutoForwardDiff (),
127
158
)
128
159
return OptimLogDensity (
129
160
Turing. LogDensityFunction (model, DynamicPPL. VarInfo (model), ctx; adtype= adtype)
0 commit comments