-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathMeasureBaseChainRulesCoreExt.jl
59 lines (41 loc) · 2.26 KB
/
MeasureBaseChainRulesCoreExt.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
# This file is a part of MeasureBase.jl, licensed under the MIT License (MIT).
module MeasureBaseChainRulesCoreExt

# ChainRulesCore reverse-mode rules for MeasureBase internals.
#
# Convention: a pullback returns one tangent per primal input, *including* the function
# itself. For `rrule(f, a, b)` the pullback returns `(∂f, ∂a, ∂b)`. `NoTangent()` marks an
# input that has no meaningful tangent; `ZeroTangent()` marks a zero gradient.

using MeasureBase
using ChainRulesCore: NoTangent, ZeroTangent
import ChainRulesCore

# = utils ====================================================================

using MeasureBase: isneginf, isposinf

# `isneginf(x)` / `isposinf(x)` are treated as non-differentiable.
# Their pullbacks return `(∂self, ∂x)` with a zero gradient for `x`.
_isneginf_pullback(::Any) = (NoTangent(), ZeroTangent())
ChainRulesCore.rrule(::typeof(isneginf), x) = isneginf(x), _isneginf_pullback

_isposinf_pullback(::Any) = (NoTangent(), ZeroTangent())
ChainRulesCore.rrule(::typeof(isposinf), x) = isposinf(x), _isposinf_pullback

# = insupport & friends ======================================================

using MeasureBase:
    check_dof, require_insupport, checked_arg, _checksupport, _origin_depth

# rrule for `_checksupport(cond, result)`.
#
# `cond` receives no gradient. The local derivative w.r.t. `result` is 1, so the incoming
# cotangent `ȳ` is passed through unchanged. The exception is `cond == false`, where the
# forward value does not depend on `result` and the gradient is zero.
# NOTE(review): this assumes `_checksupport` itself branches on `cond == false` (so that
# non-Bool "unknown" support markers keep `result`) and returns a constant such as `-Inf`
# in that branch. Confirm against MeasureBase's definition.
@inline function ChainRulesCore.rrule(::typeof(_checksupport), cond, result)
    y = _checksupport(cond, result)
    function _checksupport_pullback(ȳ)
        ∂result = cond == false ? ZeroTangent() : ȳ
        return NoTangent(), ZeroTangent(), ∂result
    end
    return y, _checksupport_pullback
end

# Tangents for `(require_insupport, μ, x)`: the support check contributes no gradient.
_require_insupport_pullback(ΔΩ) = NoTangent(), NoTangent(), ZeroTangent()
function ChainRulesCore.rrule(::typeof(require_insupport), μ, x)
    return require_insupport(μ, x), _require_insupport_pullback
end

# Tangents for `(_origin_depth, ν)`: no tangent for the measure argument.
_origin_depth_pullback(ΔΩ) = NoTangent(), NoTangent()
ChainRulesCore.rrule(::typeof(_origin_depth), ν) = _origin_depth(ν), _origin_depth_pullback

# Tangents for `(check_dof, ν, μ)`: no tangent for either measure argument.
_check_dof_pullback(ΔΩ) = NoTangent(), NoTangent(), NoTangent()
ChainRulesCore.rrule(::typeof(check_dof), ν, μ) = check_dof(ν, μ), _check_dof_pullback

# Tangents for `(checked_arg, ν, x)`: the cotangent passes straight through to `x`.
# In effect, `checked_arg` is treated as the identity in `x` for differentiation.
_checked_arg_pullback(ΔΩ) = NoTangent(), NoTangent(), ΔΩ
ChainRulesCore.rrule(::typeof(checked_arg), ν, x) = checked_arg(ν, x), _checked_arg_pullback

# = return type inference ====================================================

using MeasureBase: logdensityof_rt, strict_logdensityof_rt

# Return-type helpers are not differentiated.
# Tangents are `(∂self, ∂target, ∂v)` with no tangent for `target` and a zero gradient for
# `v`.
_logdensityof_rt_pullback(::Any) = (NoTangent(), NoTangent(), ZeroTangent())
function ChainRulesCore.rrule(::typeof(logdensityof_rt), target, v)
    return logdensityof_rt(target, v), _logdensityof_rt_pullback
end

_strict_logdensityof_rt_pullback(::Any) = (NoTangent(), NoTangent(), ZeroTangent())
function ChainRulesCore.rrule(::typeof(strict_logdensityof_rt), target, v)
    return strict_logdensityof_rt(target, v), _strict_logdensityof_rt_pullback
end

end # module MeasureBaseChainRulesCoreExt