|
| 1 | +module UtilsTests |
| 2 | + |
| 3 | +using ReverseDiffPrototype, Base.Test |
| 4 | +using ReverseDiffPrototype: Tape, TapeNode, Tracked, track, track!, value, adjoint, tape |
| 5 | + |
| 6 | +const RDP = ReverseDiffPrototype |
| 7 | + |
| 8 | +include("utils.jl") |
| 9 | + |
| 10 | +println("testing utilities...") |
| 11 | +tic() |
| 12 | + |
| 13 | +############################################################################################ |
| 14 | + |
| 15 | +tracked_is(a, b) = value(a) === value(b) && adjoint(a) === adjoint(b) && tape(a) === tape(b) |
| 16 | +tracked_is(a::AbstractArray, b::AbstractArray) = all(map(tracked_is, a, b)) |
| 17 | +tracked_is(a::Tuple, b::Tuple) = all(map(tracked_is, a, b)) |
| 18 | + |
| 19 | +################ |
| 20 | +# track/track! # |
| 21 | +################ |
| 22 | + |
| 23 | +tp = Tape() |
| 24 | +ntp = Nullable(tp) |
| 25 | +n, m = rand(), rand() |
| 26 | +x, y = rand(3), rand(3, 3) |
| 27 | +tn_int, tm_int = Tracked(n, Int, ntp), Tracked(m, Int, ntp) |
| 28 | +tn_float, tm_float = Tracked(n, ntp), Tracked(m, ntp) |
| 29 | +tx_int, ty_int = map(t -> Tracked(t, Int, ntp), x), map(t -> Tracked(t, Int, ntp), y) |
| 30 | +tx_float, ty_float = map(t -> Tracked(t, ntp), x), map(t -> Tracked(t, ntp), y) |
| 31 | + |
| 32 | +@test tracked_is(track(n, tp), tn_float) |
| 33 | +@test tracked_is(track(n, Int, tp), tn_int) |
| 34 | +@test tracked_is(track(n, Int, ntp), tn_int) |
| 35 | + |
| 36 | +@test tracked_is(track(x, tp), tx_float) |
| 37 | +@test tracked_is(track(x, Int, tp), tx_int) |
| 38 | +@test tracked_is(track(x, Int, ntp), tx_int) |
| 39 | + |
| 40 | +@test tracked_is(track((n, m), tp), (tn_float, tm_float)) |
| 41 | +@test tracked_is(track((n, m), Int, tp), (tn_int, tm_int)) |
| 42 | + |
| 43 | +@test tracked_is(track((x, y), tp), (tx_float, ty_float)) |
| 44 | +@test tracked_is(track((x, y), Int, tp), (tx_int, ty_int)) |
| 45 | + |
| 46 | +tx_int_sim = similar(tx_int) |
| 47 | +tx_float_sim = similar(tx_float) |
| 48 | +track!(tx_int_sim, x, tp) |
| 49 | +@test tracked_is(tx_int_sim, tx_int) |
| 50 | +track!(tx_float_sim, x, tp) |
| 51 | +@test tracked_is(tx_float_sim, tx_float) |
| 52 | + |
| 53 | +tx_int_sim = similar(tx_int) |
| 54 | +tx_float_sim = similar(tx_float) |
| 55 | +track!(tx_int_sim, x, ntp) |
| 56 | +@test tracked_is(tx_int_sim, tx_int) |
| 57 | +track!(tx_float_sim, x, ntp) |
| 58 | +@test tracked_is(tx_float_sim, tx_float) |
| 59 | + |
| 60 | +tx_int_sim, ty_int_sim = similar(tx_int), similar(ty_int) |
| 61 | +tx_float_sim, ty_float_sim = similar(tx_float), similar(ty_float) |
| 62 | +track!((tx_int_sim, ty_int_sim), (x, y), tp) |
| 63 | +@test tracked_is(tx_int_sim, tx_int) |
| 64 | +@test tracked_is(ty_int_sim, ty_int) |
| 65 | +track!((tx_float_sim, ty_float_sim), (x, y), tp) |
| 66 | +@test tracked_is(tx_float_sim, tx_float) |
| 67 | +@test tracked_is(ty_float_sim, ty_float) |
| 68 | + |
| 69 | +tx_int_sim, ty_int_sim = similar(tx_int), similar(ty_int) |
| 70 | +tx_float_sim, ty_float_sim = similar(tx_float), similar(ty_float) |
| 71 | +track!((tx_int_sim, ty_int_sim), (x, y), ntp) |
| 72 | +@test tracked_is(tx_int_sim, tx_int) |
| 73 | +@test tracked_is(ty_int_sim, ty_int) |
| 74 | +track!((tx_float_sim, ty_float_sim), (x, y), ntp) |
| 75 | +@test tracked_is(tx_float_sim, tx_float) |
| 76 | +@test tracked_is(ty_float_sim, ty_float) |
| 77 | + |
| 78 | +track!(tn_int, m, tp) |
| 79 | +@test tracked_is(tn_int, tm_int) |
| 80 | +track!(tn_int, n, tp) |
| 81 | + |
| 82 | +track!(tn_float, m, tp) |
| 83 | +@test tracked_is(tn_float, tm_float) |
| 84 | +track!(tn_float, n, tp) |
| 85 | + |
| 86 | +track!((tm_int, tn_int), (n, m), tp) |
| 87 | +@test value(tm_int) === n |
| 88 | +@test value(tn_int) === m |
| 89 | +track!((tm_int, tn_int), (m, n), tp) |
| 90 | + |
| 91 | +track!(tn_int, m, ntp) |
| 92 | +@test tracked_is(tn_int, tm_int) |
| 93 | +track!(tn_int, n, ntp) |
| 94 | + |
| 95 | +track!(tn_float, m, ntp) |
| 96 | +@test tracked_is(tn_float, tm_float) |
| 97 | +track!(tn_float, n, ntp) |
| 98 | + |
| 99 | +track!((tm_int, tn_int), (n, m), ntp) |
| 100 | +@test value(tm_int) === n |
| 101 | +@test value(tn_int) === m |
| 102 | +track!((tm_int, tn_int), (m, n), ntp) |
| 103 | + |
| 104 | +################################## |
| 105 | +# array accessors/tape selection # |
| 106 | +################################## |
| 107 | + |
| 108 | +tp = Tape() |
| 109 | +ntp = Nullable(tp) |
| 110 | +x, y = rand(3), rand(3, 2) |
| 111 | +tx, ty = track(x, tp), track(y, tp) |
| 112 | + |
| 113 | +@test value(tx) == x |
| 114 | +@test value(ty) == y |
| 115 | + |
| 116 | +x_sim, y_sim = similar(x), similar(y) |
| 117 | +RDP.value!(x_sim, tx) |
| 118 | +RDP.value!(y_sim, ty) |
| 119 | +@test x_sim == x |
| 120 | +@test y_sim == y |
| 121 | + |
| 122 | +@test all(adjoint(tx) .== zero(first(x))) |
| 123 | +@test all(adjoint(ty) .== zero(first(y))) |
| 124 | + |
| 125 | +x_sim, y_sim = similar(x), similar(y) |
| 126 | +RDP.adjoint!(x_sim, tx) |
| 127 | +RDP.adjoint!(y_sim, ty) |
| 128 | +@test all(x_sim .== zero(first(x))) |
| 129 | +@test all(y_sim .== zero(first(y))) |
| 130 | + |
| 131 | +@test tape(tx) === ntp |
| 132 | +@test tape(tx, ty) === ntp |
| 133 | +@test tape(tx, first(tx)) === ntp |
| 134 | +@test tape(first(tx), tx) === ntp |
| 135 | +@test tape(tx, Tracked(1)) === ntp |
| 136 | +@test tape(Tracked(1), tx) === ntp |
| 137 | +@test tape(tx, [Tracked(1)]) === ntp |
| 138 | +@test tape([Tracked(1)], tx) === ntp |
| 139 | +@test tape([Tracked(1), first(tx)]) === ntp |
| 140 | +@test isnull(tape([Tracked(1)])) |
| 141 | +@test isnull(tape([Tracked(1)], [Tracked(1)])) |
| 142 | + |
| 143 | +################# |
| 144 | +# seed!/unseed! # |
| 145 | +################# |
| 146 | + |
| 147 | +tp = Tape() |
| 148 | +ntp = Nullable(tp) |
| 149 | + |
| 150 | +@test tracked_is(RDP.seed!(Tracked(1, 0, ntp)), Tracked(1, 1, ntp)) |
| 151 | + |
| 152 | +node = TapeNode(+, (Tracked(2, ntp), Tracked(1, ntp)), Tracked(3, ntp), nothing) |
| 153 | +@test tracked_is(RDP.seed!(node).outputs, Tracked(3, 1, ntp)) |
| 154 | + |
| 155 | +@test tracked_is(RDP.unseed!(Tracked(1, 1, ntp)), Tracked(1, 0, ntp)) |
| 156 | + |
| 157 | +node = TapeNode(+, (Tracked(2, 2, ntp), Tracked(1, 3, ntp)), Tracked(3, 4, ntp), nothing) |
| 158 | +RDP.unseed!(node) |
| 159 | +@test adjoint(node.inputs[1]) === 0 |
| 160 | +@test adjoint(node.inputs[2]) === 0 |
| 161 | +@test adjoint(node.outputs) === 0 |
| 162 | + |
| 163 | +tp2 = [TapeNode(+, (Tracked(2, 2, ntp), Tracked(1, 3, ntp)), Tracked(3, 4, ntp), nothing), |
| 164 | + TapeNode(+, Tracked(1.0, 3.0, ntp), (Tracked(51.4, 3.1, ntp), Tracked(3, 4, ntp)), nothing)] |
| 165 | +RDP.unseed!(tp2) |
| 166 | +@test adjoint(tp2[1].inputs[1]) === 0 |
| 167 | +@test adjoint(tp2[1].inputs[2]) === 0 |
| 168 | +@test adjoint(tp2[1].outputs) === 0 |
| 169 | +@test adjoint(tp2[2].inputs) === 0.0 |
| 170 | +@test adjoint(tp2[2].outputs[1]) === 0.0 |
| 171 | +@test adjoint(tp2[2].outputs[2]) === 0 |
| 172 | + |
| 173 | +####################### |
| 174 | +# adjoint propagation # |
| 175 | +####################### |
| 176 | + |
| 177 | +tp = Tape() |
| 178 | +ntp = Nullable(tp) |
| 179 | +genarr = () -> [Tracked(rand(), rand(), ntp) for i in 1:3] |
| 180 | + |
| 181 | +x, y = genarr(), genarr() |
| 182 | +xadj, yadj = adjoint(x), adjoint(y) |
| 183 | +RDP.extract_and_decrement_adjoint!(x, y) |
| 184 | +@test adjoint(x) == (xadj - yadj) |
| 185 | + |
| 186 | +x, y = genarr(), genarr() |
| 187 | +xadj, yadj = adjoint(x), adjoint(y) |
| 188 | +RDP.extract_and_increment_adjoint!(x, y) |
| 189 | +@test adjoint(x) == (xadj + yadj) |
| 190 | + |
| 191 | +x, y = genarr(), rand(3) |
| 192 | +xadj = adjoint(x) |
| 193 | +RDP.increment_adjoint!(x, y) |
| 194 | +@test adjoint(x) == (xadj + y) |
| 195 | + |
| 196 | +x = genarr() |
| 197 | +xadj = adjoint(x) |
| 198 | +RDP.increment_adjoint!(x, 3) |
| 199 | +@test adjoint(x) == (xadj + 3) |
| 200 | + |
| 201 | +k = Tracked(1, 3, ntp) |
| 202 | +RDP.increment_adjoint!(k, 3) |
| 203 | +@test tracked_is(k, Tracked(1, 6, ntp)) |
| 204 | + |
| 205 | +############################################################################################ |
| 206 | + |
| 207 | +println("done (took $(toq()) seconds)") |
| 208 | + |
| 209 | +end # module |
0 commit comments