@@ -52,7 +52,84 @@ using ..DistributionsAD: TuringPoissonBinomial,
52
52
VectorOfMultivariate,
53
53
FillVectorOfMultivariate
54
54
55
- include (" reversediffx.jl" )
55
+ # ##############
56
+ # # logsumexp ##
57
+ # ##############
58
+
59
+ logsumexp (x:: TrackedArray ; dims= :) = track (logsumexp, x, dims = dims)
60
+ @grad function logsumexp (x:: AbstractArray ; dims)
61
+ x_value = value (x)
62
+ lse = logsumexp (x_value; dims= dims)
63
+ return lse, Δ -> (Δ .* exp .(x_value .- lse),)
64
+ end
65
+
66
+ # ###########
67
+ # # linalg ##
68
+ # ###########
69
+
70
+ function LinearAlgebra. cholesky (A:: Symmetric{<:Any, <:TrackedMatrix} ; check= true )
71
+ uplo = A. uplo == ' U' ? (:U ) : (:L )
72
+ factors, info = symm_turing_chol (parent (A), check, uplo)
73
+ return Cholesky {eltype(factors), typeof(factors)} (factors, ' U' , info)
74
+ end
75
+ function LinearAlgebra. cholesky (A:: TrackedMatrix ; check= true )
76
+ factors, info = turing_chol (A, check)
77
+ return Cholesky {eltype(factors), typeof(factors)} (factors, ' U' , info)
78
+ end
79
+
80
+ function symm_turing_chol (x:: TrackedArray{V,D} , check, uplo) where {V,D}
81
+ tp = tape (x)
82
+ x_value = value (x)
83
+ (factors,info), back = DistributionsAD. symm_turing_chol_back (x_value, check, uplo)
84
+ C = Cholesky {eltype(factors), typeof(factors)} (factors, ' U' , info)
85
+ out = track (C. factors, D, tp)
86
+ record! (tp, SpecialInstruction, symm_turing_chol, (x, check, uplo), out, (back, issuccess (C)))
87
+ return out, C. info
88
+ end
89
+ function turing_chol (x:: TrackedArray{V,D} , check) where {V,D}
90
+ tp = tape (x)
91
+ x_value = value (x)
92
+ (factors,info), back = DistributionsAD. turing_chol_back (x_value, check)
93
+ C = Cholesky {eltype(factors), typeof(factors)} (factors, ' U' , info)
94
+ out = track (C. factors, D, tp)
95
+ record! (tp, SpecialInstruction, turing_chol, (x, check), out, (back, issuccess (C)))
96
+ return out, C. info
97
+ end
98
+
99
+ for f in (:turing_chol , :symm_turing_chol )
100
+ @eval begin
101
+ @noinline function ReverseDiff. special_reverse_exec! (
102
+ instruction:: SpecialInstruction{typeof($f)} ,
103
+ )
104
+ output = instruction. output
105
+ instruction. cache[2 ] || throw (PosDefException (C. info))
106
+ input = instruction. input
107
+ input_deriv = deriv (input[1 ])
108
+ P = instruction. cache[1 ]
109
+ input_deriv .+ = P ((factors = deriv (output),))[1 ]
110
+ unseed! (output)
111
+ return nothing
112
+ end
113
+ end
114
+ end
115
+
116
+ @noinline function ReverseDiff. special_forward_exec! (
117
+ instruction:: SpecialInstruction{typeof(turing_chol)} ,
118
+ )
119
+ output, input = instruction. output, instruction. input
120
+ factors = turing_chol (value .(input)... )[1 ]
121
+ value! (output, factors)
122
+ return nothing
123
+ end
124
+
125
+ @noinline function ReverseDiff. special_forward_exec! (
126
+ instruction:: SpecialInstruction{typeof(symm_turing_chol)} ,
127
+ )
128
+ output, input = instruction. output, instruction. input
129
+ factors = symm_turing_chol (value .(input)... )[1 ]
130
+ value! (output, factors)
131
+ return nothing
132
+ end
56
133
57
134
adapt_randn (rng:: Random.AbstractRNG , x:: TrackedArray , dims... ) = adapt_randn (rng, value (x), dims... )
58
135
0 commit comments