Skip to content

Commit 1a23326

Browse files
committed
Add the Enzyme rules example.
1 parent cc4c332 commit 1a23326

File tree

1 file changed

+49
-0
lines changed

1 file changed

+49
-0
lines changed

julia_custom/custom.jl

+49
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
using Enzyme
2+
using Enzyme: EnzymeRules
3+
4+
# Defining our function
5+
f(x) = x^2;
6+
7+
function f_ip(x)
8+
x[1] *= x[1]
9+
return nothing
10+
end
11+
12+
import .EnzymeRules: augmented_primal, reverse, Annotation, has_rrule, has_rrule_from_sig
13+
using .EnzymeRules
14+
15+
function augmented_primal(config::ConfigWidth{1}, func::Const{typeof(f)}, ::Type{<:Active}, x::Active)
16+
if needs_primal(config)
17+
return AugmentedReturn(func.val(x.val), nothing, nothing)
18+
else
19+
return AugmentedReturn(nothing, nothing, nothing)
20+
end
21+
end
22+
23+
function reverse(config::ConfigWidth{1}, ::Const{typeof(f)}, dret::Active, tape, x::Active)
24+
if needs_primal(config)
25+
return (10+2*x.val*dret.val,)
26+
else
27+
return (100+2*x.val*dret.val,)
28+
end
29+
end
30+
31+
function augmented_primal(::Config{false, false, 1}, func::Const{typeof(f_ip)}, ::Type{<:Const}, x::Duplicated)
32+
v = x.val[1]
33+
x.val[1] *= v
34+
return AugmentedReturn(nothing, nothing, v)
35+
end
36+
37+
function reverse(::Config{false, false, 1}, ::Const{typeof(f_ip)}, ::Type{<:Const}, tape, x::Duplicated)
38+
x.dval[1] = 100 + x.dval[1] * tape
39+
return ()
40+
end
41+
42+
# To which we can then apply the Enzyme calls
43+
Enzyme.autodiff(Enzyme.Reverse, f, Active(2.0))[1][1];
44+
Enzyme.autodiff(Enzyme.Reverse, x->f(x)^2, Active(2.0))[1][1];
45+
46+
x = [2.0];
47+
dx = [1.0];
48+
49+
Enzyme.autodiff(Enzyme.Reverse, f_ip, Duplicated(x, dx));

0 commit comments

Comments
 (0)