Skip to content

Commit f15a006

Browse files
committed
Improve OptimLogDensity docstring
1 parent f7a416e commit f15a006

File tree

1 file changed

+37
-6
lines changed

1 file changed

+37
-6
lines changed

src/optimisation/Optimisation.jl

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -94,12 +94,43 @@ function DynamicPPL.tilde_observe(
9494
end
9595

9696
"""
97-
OptimLogDensity{M<:DynamicPPL.Model,V<:DynamicPPL.VarInfo,C<:OptimizationContext,AD<:ADTypes.AbstractADType}
97+
OptimLogDensity{
98+
M<:DynamicPPL.Model,
99+
V<:DynamicPPL.VarInfo,
100+
C<:OptimizationContext,
101+
AD<:ADTypes.AbstractADType
102+
}
103+
104+
A struct that wraps a single LogDensityFunction. Can be invoked either using
105+
106+
```julia
107+
OptimLogDensity(model, varinfo, ctx; adtype=adtype)
108+
```
109+
110+
or
111+
112+
```julia
113+
OptimLogDensity(model, ctx; adtype=adtype)
114+
```
115+
116+
If not specified, `adtype` defaults to `AutoForwardDiff()`.
117+
118+
An OptimLogDensity does not, in itself, obey the LogDensityProblems interface.
119+
Thus, if you want to calculate the log density of its contents at the point
120+
`z`, you should manually call
121+
122+
```julia
123+
LogDensityProblems.logdensity(f.ldf, z)
124+
```
98125
99-
A struct that stores the negative log density function of a `DynamicPPL` model.
126+
However, it is a callable object which returns the *negative* log density of
127+
the underlying LogDensityFunction at the point `z`. This is done to satisfy
128+
the Optim.jl interface.
100129
101-
TODO(penelopeysm): It _doesn't_ really store the negative, does it? It's more like we
102-
overrode logdensity to give the negative logdensity.
130+
```julia
131+
optim_ld = OptimLogDensity(model, varinfo, ctx)
132+
optim_ld(z) # returns -logp
133+
```
103134
"""
104135
struct OptimLogDensity{
105136
M<:DynamicPPL.Model,
@@ -114,7 +145,7 @@ function OptimLogDensity(
114145
model::DynamicPPL.Model,
115146
vi::DynamicPPL.VarInfo,
116147
ctx::OptimizationContext;
117-
adtype::Union{Nothing,ADTypes.AbstractADType}=AutoForwardDiff(),
148+
adtype::ADTypes.AbstractADType=AutoForwardDiff(),
118149
)
119150
return OptimLogDensity(Turing.LogDensityFunction(model, vi, ctx; adtype=adtype))
120151
end
@@ -123,7 +154,7 @@ end
123154
function OptimLogDensity(
124155
model::DynamicPPL.Model,
125156
ctx::OptimizationContext;
126-
adtype::Union{Nothing,ADTypes.AbstractADType}=AutoForwardDiff(),
157+
adtype::ADTypes.AbstractADType=AutoForwardDiff(),
127158
)
128159
return OptimLogDensity(
129160
Turing.LogDensityFunction(model, DynamicPPL.VarInfo(model), ctx; adtype=adtype)

0 commit comments

Comments
 (0)