Remove LogDensityProblemsAD; wrap adtype in LogDensityFunction #806

Merged · 11 commits · Feb 19, 2025
48 changes: 47 additions & 1 deletion HISTORY.md
@@ -2,7 +2,7 @@

## 0.35.0

**Breaking**
**Breaking changes**

### `.~` right hand side must be a univariate distribution

@@ -119,6 +119,52 @@ This release removes the feature of `VarInfo` where it kept track of which varia

This change also affects sampling in Turing.jl.

### `LogDensityFunction` argument order

- The method `LogDensityFunction(varinfo, model, context)` has been removed.
The only accepted order is `LogDensityFunction(model, varinfo, context; adtype)`.
(For an explanation of `adtype`, see below.)
The varinfo and context arguments are both still optional.
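
  A minimal sketch of the new calling convention, using a hypothetical one-parameter model `demo`:

  ```julia
  using DynamicPPL, Distributions

  # Hypothetical toy model for illustration.
  @model demo() = x ~ Normal()

  model = demo()
  vi = VarInfo(model)

  # The old order, LogDensityFunction(vi, model, DefaultContext()), now errors.
  # New order; varinfo and context are still optional:
  ldf = LogDensityFunction(model, vi, DefaultContext())
  ```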

**Other changes**

### `LogDensityProblems` interface

LogDensityProblemsAD has now been removed as a dependency.
Instead of constructing a `LogDensityProblemsAD.ADgradient` object, we now use `DifferentiationInterface` directly to calculate the gradient of the log density with respect to model parameters.

Note that if you wish, you can still construct an `ADgradient` out of a `LogDensityFunction` object (there is nothing preventing this).
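
For example, a sketch of how this might look, assuming you add LogDensityProblemsAD and ForwardDiff to your own environment (DynamicPPL no longer depends on them):

```julia
using DynamicPPL, Distributions, ADTypes
using ForwardDiff                       # backend for AutoForwardDiff
using LogDensityProblemsAD: ADgradient

@model demo() = x ~ Normal()  # hypothetical toy model

ldf = LogDensityFunction(demo())
# Wrapping in an ADgradient still works, though it is no longer needed
# for gradient calculations within DynamicPPL:
∇ldf = ADgradient(AutoForwardDiff(), ldf)
```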

However, in this version `LogDensityFunction` takes an extra AD type argument.
If this argument is not provided, the behaviour is exactly the same as before, i.e. you can calculate `logdensity` but not its gradient.
If you do pass an AD type, you can calculate the gradient as well.
You may thus find it easier to do this instead:

```julia
@model f() = ...

ldf = LogDensityFunction(f(); adtype=AutoForwardDiff())
```

This returns an object which satisfies the `LogDensityProblems` interface up to first order, i.e. you can now directly call both

```julia
LogDensityProblems.logdensity(ldf, params)
LogDensityProblems.logdensity_and_gradient(ldf, params)
```

without having to construct a separate `ADgradient` object.
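
Putting this together, a minimal end-to-end sketch (again with a hypothetical one-parameter model):

```julia
using DynamicPPL, Distributions, ADTypes
using ForwardDiff  # backend required by AutoForwardDiff
using LogDensityProblems

@model demo() = x ~ Normal()

ldf = LogDensityFunction(demo(); adtype=AutoForwardDiff())

params = [0.5]  # one entry per model parameter
LogDensityProblems.logdensity(ldf, params)
LogDensityProblems.logdensity_and_gradient(ldf, params)  # (logp, gradient)
```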

If you prefer, you can also construct a new `LogDensityFunction` with a new AD type afterwards.
The model, varinfo, and context will be taken from the original `LogDensityFunction`:

```julia
@model f() = ...

ldf = LogDensityFunction(f()) # by default, no adtype set
ldf_with_ad = LogDensityFunction(ldf, AutoForwardDiff())
```
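
`ldf_with_ad` then supports `LogDensityProblems.logdensity_and_gradient` just as if the AD type had been passed at construction.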

## 0.34.2

- Fixed bugs in ValuesAsInModelContext as well as DebugContext where underlying PrefixContexts were not being applied.
7 changes: 3 additions & 4 deletions Project.toml
@@ -12,13 +12,13 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -51,15 +51,14 @@ Bijectors = "0.13.18, 0.14, 0.15"
ChainRulesCore = "1"
Compat = "4"
ConstructionBase = "1.5.4"
DifferentiationInterface = "0.6.41"
Distributions = "0.25"
DocStringExtensions = "0.9"
KernelAbstractions = "0.9.33"
EnzymeCore = "0.6 - 0.8"
ForwardDiff = "0.10"
JET = "0.9"
KernelAbstractions = "0.9.33"
LinearAlgebra = "1.6"
LogDensityProblems = "2"
LogDensityProblemsAD = "1.7.0"
MCMCChains = "6"
MacroTools = "0.5.6"
Mooncake = "0.4.95"
2 changes: 1 addition & 1 deletion docs/src/api.md
@@ -54,7 +54,7 @@ logjoint

### LogDensityProblems.jl interface

The [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) interface is also supported by simply wrapping a [`Model`](@ref) in a `DynamicPPL.LogDensityFunction`:
The [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) interface is also supported by wrapping a [`Model`](@ref) in a `DynamicPPL.LogDensityFunction`.

```@docs
DynamicPPL.LogDensityFunction
72 changes: 29 additions & 43 deletions ext/DynamicPPLForwardDiffExt.jl
@@ -1,54 +1,40 @@
module DynamicPPLForwardDiffExt

if isdefined(Base, :get_extension)
using DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD
using ForwardDiff
else
using ..DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD
using ..ForwardDiff
end

getchunksize(::ADTypes.AutoForwardDiff{chunk}) where {chunk} = chunk

standardtag(::ADTypes.AutoForwardDiff{<:Any,Nothing}) = true
standardtag(::ADTypes.AutoForwardDiff) = false

function LogDensityProblemsAD.ADgradient(
ad::ADTypes.AutoForwardDiff, ℓ::DynamicPPL.LogDensityFunction
)
θ = DynamicPPL.getparams(ℓ)
f = Base.Fix1(LogDensityProblems.logdensity, ℓ)

# Define configuration for ForwardDiff.
tag = if standardtag(ad)
ForwardDiff.Tag(DynamicPPL.DynamicPPLTag(), eltype(θ))
using DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems
using ForwardDiff

# check if the AD type already has a tag
use_dynamicppl_tag(::ADTypes.AutoForwardDiff{<:Any,Nothing}) = true
use_dynamicppl_tag(::ADTypes.AutoForwardDiff) = false

function DynamicPPL.tweak_adtype(
ad::ADTypes.AutoForwardDiff{chunk_size},
::DynamicPPL.Model,
vi::DynamicPPL.AbstractVarInfo,
::DynamicPPL.AbstractContext,
) where {chunk_size}
params = vi[:]

# Use DynamicPPL tag to improve stack traces
# https://www.stochasticlifestyle.com/improved-forwarddiff-jl-stacktraces-with-package-tags/
# NOTE: DifferentiationInterface disables tag checking if the
# tag inside the AutoForwardDiff type is not nothing. See
# https://github.com/JuliaDiff/DifferentiationInterface.jl/blob/1df562180bdcc3e91c885aa5f4162a0be2ced850/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl#L338-L350.
# So we don't currently need to override ForwardDiff.checktag as well.
tag = if use_dynamicppl_tag(ad)
ForwardDiff.Tag(DynamicPPL.DynamicPPLTag(), eltype(params))
else
ForwardDiff.Tag(f, eltype(θ))
ad.tag

end
chunk_size = getchunksize(ad)

# Optimise chunk size according to size of model
chunk = if chunk_size == 0 || chunk_size === nothing
ForwardDiff.Chunk(θ)
ForwardDiff.Chunk(params)
else
ForwardDiff.Chunk(length(θ), chunk_size)
ForwardDiff.Chunk(length(params), chunk_size)

end

return LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), ℓ; chunk, tag, x=θ)
end

# Allow Turing tag in gradient etc. calls of the log density function
function ForwardDiff.checktag(
::Type{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag,V}},
::DynamicPPL.LogDensityFunction,
::AbstractArray{W},
) where {V,W}
return true
end
function ForwardDiff.checktag(
::Type{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag,V}},
::Base.Fix1{typeof(LogDensityProblems.logdensity),<:DynamicPPL.LogDensityFunction},
::AbstractArray{W},
) where {V,W}
return true
return ADTypes.AutoForwardDiff(; chunksize=ForwardDiff.chunksize(chunk), tag=tag)
end

end # module
1 change: 0 additions & 1 deletion src/DynamicPPL.jl
@@ -14,7 +14,6 @@ using MacroTools: MacroTools
using ConstructionBase: ConstructionBase
using Accessors: Accessors
using LogDensityProblems: LogDensityProblems
using LogDensityProblemsAD: LogDensityProblemsAD

using LinearAlgebra: LinearAlgebra, Cholesky

1 change: 1 addition & 0 deletions src/contexts.jl
@@ -184,6 +184,7 @@
getsampler(context::SamplingContext) = context.sampler
getsampler(context::AbstractContext) = getsampler(NodeTrait(context), context)
getsampler(::IsParent, context::AbstractContext) = getsampler(childcontext(context))
getsampler(::IsLeaf, ::AbstractContext) = error("No sampler found in context")

"""
struct DefaultContext <: AbstractContext end