-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathMeasureBaseChainRulesCoreExt.jl
84 lines (58 loc) · 2.65 KB
/
MeasureBaseChainRulesCoreExt.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# This file is a part of MeasureBase.jl, licensed under the MIT License (MIT).
module MeasureBaseChainRulesCoreExt
using MeasureBase
using MeasureBase: require_insupport, check_dof, checked_arg
using MeasureBase: _checksupport, _origin_depth
using ChainRulesCore: NoTangent, ZeroTangent
import ChainRulesCore
# = collection utils =========================================================
using MeasureBase: _dropfront, _dropback, _pushfront, _pushback
using MeasureBase: _rev_cumsum, _exp_cumsum_log
# Reverse-mode rule for `_pushfront(v, x)`.
#
# The pullback splits the (unthunked) output cotangent by position: the entry
# at its first index becomes the cotangent of `x`, and every later entry becomes
# the cotangent of `v`. The function itself receives `NoTangent()`.
# NOTE(review): this presumes `_pushfront` places `x` before the elements of
# `v` — confirm against its definition in MeasureBase.
function ChainRulesCore.rrule(::typeof(_pushfront), v::AbstractVector, x)
    function _pushfront_pullback(thunked_ΔΩ)
        Δ = ChainRulesCore.unthunk(thunked_ΔΩ)
        head = firstindex(Δ)
        ∂v = Δ[(head + 1):lastindex(Δ)]
        ∂x = Δ[head]
        return (NoTangent(), ∂v, ∂x)
    end
    return _pushfront(v, x), _pushfront_pullback
end
# Reverse-mode rule for `_pushback(v, x)`.
#
# The pullback splits the (unthunked) output cotangent by position: the entry
# at its last index becomes the cotangent of `x`, and every earlier entry
# becomes the cotangent of `v`. The function itself receives `NoTangent()`.
# NOTE(review): this presumes `_pushback` places `x` after the elements of
# `v` — confirm against its definition in MeasureBase.
function ChainRulesCore.rrule(::typeof(_pushback), v::AbstractVector, x)
    function _pushback_pullback(thunked_ΔΩ)
        Δ = ChainRulesCore.unthunk(thunked_ΔΩ)
        tail = lastindex(Δ)
        ∂v = Δ[firstindex(Δ):(tail - 1)]
        ∂x = Δ[tail]
        return (NoTangent(), ∂v, ∂x)
    end
    return _pushback(v, x), _pushback_pullback
end
# Reverse-mode rule for `_rev_cumsum(xs)`.
#
# The cotangent of `xs` is the forward cumulative sum of the output cotangent.
# It is wrapped in a thunk, so the `cumsum` only runs if a consumer unthunks it.
function ChainRulesCore.rrule(::typeof(_rev_cumsum), xs::AbstractVector)
    function _rev_cumsum_pullback(ΔΩ)
        lazy_∂xs = ChainRulesCore.@thunk(cumsum(ChainRulesCore.unthunk(ΔΩ)))
        return (NoTangent(), lazy_∂xs)
    end
    return _rev_cumsum(xs), _rev_cumsum_pullback
end
# Reverse-mode rule for `_exp_cumsum_log(xs)`.
#
# The pullback weights the output cotangent elementwise by
# `exp.(cumsum(log.(xs)))`, accumulates the weighted values with `_rev_cumsum`,
# and divides elementwise by `xs`. Unlike the `_rev_cumsum` rule, the result is
# computed eagerly rather than returned as a thunk.
# NOTE(review): `exp.(cumsum(log.(xs)))` looks like it recomputes the primal
# result; if `_exp_cumsum_log` is defined exactly that way, the primal output
# could be reused here — confirm before changing.
function ChainRulesCore.rrule(::typeof(_exp_cumsum_log), xs::AbstractVector)
    function _exp_cumsum_log_pullback(ΔΩ)
        weights = exp.(cumsum(log.(xs)))
        Δ = ChainRulesCore.unthunk(ΔΩ)
        ∂xs = inv.(xs) .* _rev_cumsum(weights .* Δ)
        return (NoTangent(), ∂xs)
    end
    return _exp_cumsum_log(xs), _exp_cumsum_log_pullback
end
# = measure functions ========================================================
# Reverse-mode rule for `_checksupport(cond, result)`.
#
# Returns the primal value together with a pullback yielding one tangent per
# input: `NoTangent()` for the function, `ZeroTangent()` for `cond`, and the
# incoming cotangent `ȳ` for `result`.
#
# A pullback must return the vector-Jacobian product, not the local derivative.
# The rule treats the output as having unit derivative with respect to
# `result`, so the cotangent of `result` is `ȳ` itself. Returning `one(ȳ)`
# instead would drop any upstream scaling (wrong gradients whenever `ȳ != 1`)
# and has no meaning for non-numeric cotangents.
# NOTE(review): presumably `_checksupport` passes `result` through unchanged
# when `cond` holds — confirm against its definition in MeasureBase.
@inline function ChainRulesCore.rrule(::typeof(_checksupport), cond, result)
    y = _checksupport(cond, result)
    function _checksupport_pullback(ȳ)
        return NoTangent(), ZeroTangent(), ȳ
    end
    return y, _checksupport_pullback
end
# Pullback for `require_insupport(μ, x)`.
#
# An rrule pullback returns one tangent for the function itself plus one per
# primal argument, so a two-argument primal needs three entries:
# `NoTangent()` for the function, `NoTangent()` for `μ`, and `ZeroTangent()`
# for `x`. The incoming cotangent is ignored.
_require_insupport_pullback(ΔΩ) = (NoTangent(), NoTangent(), ZeroTangent())

# Reverse-mode rule for `require_insupport(μ, x)`: evaluates the primal call and
# pairs it with the constant pullback above.
function ChainRulesCore.rrule(::typeof(require_insupport), μ, x)
    return require_insupport(μ, x), _require_insupport_pullback
end
# Pullback for `_origin_depth(ν)`: ignores the incoming cotangent and reports
# `NoTangent()` for both the function and its single argument.
function _origin_depth_pullback(ΔΩ)
    return (NoTangent(), NoTangent())
end

# Reverse-mode rule for `_origin_depth(ν)`: the primal value paired with the
# constant pullback above.
function ChainRulesCore.rrule(::typeof(_origin_depth), ν)
    depth = _origin_depth(ν)
    return depth, _origin_depth_pullback
end
# Pullback for `check_dof(ν, μ)`: ignores the incoming cotangent and reports
# `NoTangent()` for the function and for each of its two arguments.
function _check_dof_pullback(ΔΩ)
    return (NoTangent(), NoTangent(), NoTangent())
end

# Reverse-mode rule for `check_dof(ν, μ)`: the primal value paired with the
# constant pullback above.
function ChainRulesCore.rrule(::typeof(check_dof), ν, μ)
    checked = check_dof(ν, μ)
    return checked, _check_dof_pullback
end
# Pullback for `checked_arg(ν, x)`: the function and `ν` get `NoTangent()`,
# while the incoming cotangent is forwarded unchanged as the cotangent of `x`.
function _checked_arg_pullback(ΔΩ)
    return (NoTangent(), NoTangent(), ΔΩ)
end

# Reverse-mode rule for `checked_arg(ν, x)`: the primal value paired with the
# pass-through pullback above.
function ChainRulesCore.rrule(::typeof(checked_arg), ν, x)
    checked = checked_arg(ν, x)
    return checked, _checked_arg_pullback
end
end # module MeasureBaseChainRulesCoreExt