@@ -1,3 +1,5 @@
+import DifferentiationInterface as DI
+
 """
     LogDensityFunction

@@ -81,37 +83,13 @@

 Return the `DynamicPPL.Model` wrapped in the given log-density function `f`.
 """
-getmodel(f::LogDensityProblemsAD.ADGradientWrapper) =
-    getmodel(LogDensityProblemsAD.parent(f))
 getmodel(f::DynamicPPL.LogDensityFunction) = f.model

 """
-    setmodel(f, model[, adtype])
+    setmodel(f, model)

 Set the `DynamicPPL.Model` in the given log-density function `f` to `model`.
-
-!!! warning
-    Note that if `f` is a `LogDensityProblemsAD.ADGradientWrapper` wrapping a
-    `DynamicPPL.LogDensityFunction`, performing an update of the `model` in `f`
-    might require recompilation of the gradient tape, depending on the AD backend.
 """
-function setmodel(
-    f::LogDensityProblemsAD.ADGradientWrapper,
-    model::DynamicPPL.Model,
-    adtype::ADTypes.AbstractADType,
-)
-    # TODO: Should we handle `SciMLBase.NoAD`?
-    # For an `ADGradientWrapper` we do the following:
-    # 1. Update the `Model` in the underlying `LogDensityFunction`.
-    # 2. Re-construct the `ADGradientWrapper` using `ADgradient` using the provided `adtype`
-    #    to ensure that the recompilation of gradient tapes, etc. also occur. For example,
-    #    ReverseDiff.jl in compiled mode will cache the compiled tape, which means that just
-    #    replacing the corresponding field with the new model won't be sufficient to obtain
-    #    the correct gradients.
-    return LogDensityProblemsAD.ADgradient(
-        adtype, setmodel(LogDensityProblemsAD.parent(f), model)
-    )
-end
 function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model)
     return Accessors.@set f.model = model
 end
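For reference, a rough usage sketch of the two accessors kept by this hunk. Hedged: the `demo` model and the direct `DynamicPPL.LogDensityFunction(model)` construction are illustrative assumptions, and `getmodel`/`setmodel` are assumed to be in scope from this file's module.

```julia
using DynamicPPL, Distributions

@model function demo(x)
    μ ~ Normal(0, 1)
    x ~ Normal(μ, 1)
end

f = DynamicPPL.LogDensityFunction(demo(1.0))
getmodel(f) === f.model      # true: just reads the wrapped model
f2 = setmodel(f, demo(2.0))  # returns an updated copy via Accessors.@set; `f` is untouched
```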
@@ -140,18 +118,24 @@ end |
 # TODO: should we instead implement and call on `length(f.varinfo)` (at least in the cases where no sampler is involved)?
 LogDensityProblems.dimension(f::LogDensityFunction) = length(getparams(f))

-# This is important for performance -- one needs to provide `ADGradient` with a vector of
-# parameters, or DifferentiationInterface will not have sufficient information to e.g.
-# compile a rule for Mooncake (because it won't know the type of the input), or pre-allocate
-# a tape when using ReverseDiff.jl.
-function _make_ad_gradient(ad::ADTypes.AbstractADType, ℓ::LogDensityFunction)
-    x = map(identity, getparams(ℓ)) # ensure we concretise the elements of the params
-    return LogDensityProblemsAD.ADgradient(ad, ℓ; x)
-end
+_flipped_logdensity(θ, f) = LogDensityProblems.logdensity(f, θ)
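DifferentiationInterface differentiates with respect to the first positional argument only, hence the flipped argument order here; `f` is threaded through as a `DI.Constant` context instead. A minimal, self-contained sketch of that convention (toy function, not DynamicPPL code):

```julia
import DifferentiationInterface as DI
import ForwardDiff  # the backend package must be loaded for AutoForwardDiff to work
using ADTypes: AutoForwardDiff

g(x, c) = c * sum(abs2, x)  # differentiate w.r.t. `x`; `c` is held constant
val, grad = DI.value_and_gradient(g, AutoForwardDiff(), [1.0, 2.0], DI.Constant(3.0))
# val == 15.0, grad == [6.0, 12.0]
```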
 
-function LogDensityProblemsAD.ADgradient(ad::ADTypes.AutoMooncake, f::LogDensityFunction)
-    return _make_ad_gradient(ad, f)
+# By default, the AD backend to use is inferred from the context, which would
+# typically be a SamplingContext which contains a sampler.
+function LogDensityProblems.logdensity_and_gradient(
+    f::LogDensityFunction, θ::AbstractVector
+)
+    adtype = getadtype(getsampler(getcontext(f)))
+    return LogDensityProblems.logdensity_and_gradient(f, θ, adtype)
 end
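Since this method only looks the backend up from the wrapped context, the two calls below are equivalent whenever `f` carries a `SamplingContext` with a sampler (sketch; assumes such an `f` and a parameter vector `θ` are in scope):

```julia
LogDensityProblems.logdensity_and_gradient(f, θ)
LogDensityProblems.logdensity_and_gradient(f, θ, getadtype(getsampler(getcontext(f))))
```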
-function LogDensityProblemsAD.ADgradient(ad::ADTypes.AutoReverseDiff, f::LogDensityFunction)
-    return _make_ad_gradient(ad, f)
+
+# Extra method allowing one to manually specify the AD backend to use, thus
+# overriding the default AD backend inferred from the sampler.
+function LogDensityProblems.logdensity_and_gradient(
+    f::LogDensityFunction, θ::AbstractVector, adtype::ADTypes.AbstractADType
+)
+    # Ensure we concretise the elements of the params.
+    θ = map(identity, θ)
+    prep = DI.prepare_gradient(_flipped_logdensity, adtype, θ, DI.Constant(f))
+    return DI.value_and_gradient(_flipped_logdensity, prep, adtype, θ, DI.Constant(f))
 end
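Finally, a hedged end-to-end sketch of the explicit-backend method, reusing `f` from the sketch above; that ReverseDiff is the backend in play (and loaded) is an assumption for illustration:

```julia
import LogDensityProblems, ReverseDiff
using ADTypes: AutoReverseDiff

θ = zeros(LogDensityProblems.dimension(f))
logp, grad = LogDensityProblems.logdensity_and_gradient(f, θ, AutoReverseDiff(; compile=true))
```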