Skip to content

Commit 96e0f54

Browse files
authored
Merge pull request #626 from JuliaDiff/ox/mutabletangent
Introduce MutableTangent
2 parents 6664e8f + 73b7508 commit 96e0f54

File tree

12 files changed

+1115
-585
lines changed

12 files changed

+1115
-585
lines changed

Project.toml

+7-7
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,17 @@ Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
88
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
99

10-
[weakdeps]
11-
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
12-
13-
[extensions]
14-
ChainRulesCoreSparseArraysExt = "SparseArrays"
15-
1610
[compat]
1711
BenchmarkTools = "0.5"
18-
Compat = "2, 3, 4"
12+
Compat = "3.40, 4"
1913
FiniteDifferences = "0.10"
2014
OffsetArrays = "1"
2115
StaticArrays = "0.11, 0.12, 1"
2216
julia = "1.6"
2317

18+
[extensions]
19+
ChainRulesCoreSparseArraysExt = "SparseArrays"
20+
2421
[extras]
2522
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
2623
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
@@ -31,3 +28,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3128

3229
[targets]
3330
test = ["Test", "BenchmarkTools", "FiniteDifferences", "OffsetArrays", "SparseArrays", "StaticArrays"]
31+
32+
[weakdeps]
33+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

docs/make.jl

+1
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ makedocs(;
6161
"`@opt_out`" => "rule_author/superpowers/opt_out.md",
6262
"`RuleConfig`" => "rule_author/superpowers/ruleconfig.md",
6363
"Gradient accumulation" => "rule_author/superpowers/gradient_accumulation.md",
64+
"Mutation Support (experimental)" => "rule_author/superpowers/mutation_support.md",
6465
],
6566
"Converting ZygoteRules.@adjoint to rrules" => "rule_author/converting_zygoterules.md",
6667
"Tips for making your package work with AD" => "rule_author/tips_for_packages.md",

docs/src/api.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ Modules = [ChainRulesCore]
2020
Pages = [
2121
"tangent_types/abstract_zero.jl",
2222
"tangent_types/one.jl",
23-
"tangent_types/tangent.jl",
23+
"tangent_types/structural_tangent.jl",
2424
"tangent_types/thunks.jl",
2525
"tangent_types/abstract_tangent.jl",
2626
"tangent_types/notimplemented.jl",
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# Mutation Support
2+
3+
ChainRulesCore.jl offers experimental support for mutation, targeting use in forward mode AD.
4+
(Mutation support in reverse mode AD is more complicated and will likely require more changes to the interface)
5+
6+
!!! warning "Experimental"
7+
This page documents an experimental feature.
8+
Expect breaking changes in minor versions while this remains.
9+
It is not suitable for general use unless you are prepared to modify how you are using it each minor release.
10+
It is thus suggested that if you are using it to use _tilde_ bounds on supported minor versions.
11+
12+
13+
## `MutableTangent`
14+
The [`MutableTangent`](@ref) type is designed to be a partner to the [`Tangent`](@ref) type, with specific support for being mutated in place.
15+
It is required to be a structural tangent, having one tangent for each field of the primal object.
16+
17+
Technically, not all `mutable struct`s need to use `MutableTangent` to represent their tangents.
18+
Just like not all `struct`s need to use `Tangent`s.
19+
Common examples away from this are natural tangent types like for arrays.
20+
However, if one is setting up to use a custom tangent type for this it is sufficiently off the beaten path that we can not provide much guidance.
21+
22+
## `zero_tangent`
23+
24+
The [`zero_tangent`](@ref) function functions to give you a zero (i.e. additive identity) for any primal value.
25+
The [`ZeroTangent`](@ref) type also does this.
26+
The difference is that [`zero_tangent`](@ref) is in general full structural tangent mirroring the structure of the primal.
27+
To be technical the promise of [`zero_tangent`](@ref) is that it will be a value that supports mutation.
28+
However, in practice[^1] this is achieved through in a structural tangent
29+
For mutation support this is important, since it means that there is mutable memory available in the tangent to be mutated when the primal changes.
30+
To support this you thus need to make sure your zeros are created in various places with [`zero_tangent`](@ref) rather than []`ZeroTangent`](@ref).
31+
32+
33+
34+
It is also useful for reasons of type stability, since it forces a consistent type (generally a structural tangent) for any given primal type.
35+
For this reason AD system implementors might chose to use this to create the tangent for all literal values they encounter, mutable or not,
36+
and to process the output of `frule`s to convert [`ZeroTangent`](@ref) into corresponding [`zero_tangent`](@ref)s.
37+
38+
## Writing a frule for a mutating function
39+
It is relatively straight forward to write a frule for a mutating function.
40+
There are a few key points to follow:
41+
- There must be a mutable tangent input for every mutated primal input
42+
- When the primal value is changed, the corresponding change must be made to its tangent partner
43+
- When a value is returned, return its partnered tangent.
44+
- If (and only if) primal values alias, then their tangents must also alias.
45+
46+
### Example
47+
For example, consider the primal function with:
48+
1. takes two `Ref`s
49+
2. doubles the first one in place
50+
3. overwrites the second one's value with the literal 5.0
51+
4. returns the first one
52+
53+
54+
```julia
55+
function foo!(a::Base.RefValue, b::Base.RefValue)
56+
a[] *= 2
57+
b[] = 5.0
58+
return a
59+
end
60+
```
61+
62+
The frule for this would be:
63+
```julia
64+
function ChainRulesCore.frule((_, ȧ, ḃ), ::typeof(foo!), a::Base.RefValue, b::Base.RefValue)
65+
@assertisa MutableTangent{typeof(a)}
66+
@assertisa MutableTangent{typeof(b)}
67+
68+
a[] *= 2
69+
.x *= 2 # `.x` is the field that lives behind RefValues
70+
71+
b[] = 5.0
72+
.x = zero_tangent(5.0) # or since we know that the zero for a Float64 is zero could write `ḃ.x = 0.0`
73+
74+
return a, ȧ
75+
end
76+
```
77+
78+
Then assuming the AD system does its part to makes sure you are indeed given mutable values to mutate (i.e. those `@assert`ions are true) then all is well and this rule will make mutation correct.
79+
80+
[^1]:
81+
Further, it is hard to achieve this promise of allowing mutation to be supported without returning a structural tangent.
82+
Except in the special case of where the struct is not mutable and has no nested fields that are mutable.

src/ChainRulesCore.jl

+4-4
Original file line numberDiff line numberDiff line change
@@ -2,26 +2,26 @@ module ChainRulesCore
22
using Base.Broadcast: broadcasted, Broadcasted, broadcastable, materialize, materialize!
33
using Base.Meta
44
using LinearAlgebra
5-
using Compat: hasfield, hasproperty
5+
using Compat: hasfield, hasproperty, ismutabletype
66

77
export frule, rrule # core function
88
# rule configurations
99
export RuleConfig, HasReverseMode, NoReverseMode, HasForwardsMode, NoForwardsMode
1010
export frule_via_ad, rrule_via_ad
1111
# definition helper macros
1212
export @non_differentiable, @opt_out, @scalar_rule, @thunk, @not_implemented
13-
export ProjectTo, canonicalize, unthunk # tangent operations
13+
export ProjectTo, canonicalize, unthunk, zero_tangent # tangent operations
1414
export add!!, is_inplaceable_destination # gradient accumulation operations
1515
export ignore_derivatives, @ignore_derivatives
1616
# tangents
17-
export Tangent, NoTangent, InplaceableThunk, Thunk, ZeroTangent, AbstractZero, AbstractThunk
17+
export StructuralTangent, Tangent, MutableTangent, NoTangent, InplaceableThunk, Thunk, ZeroTangent, AbstractZero, AbstractThunk
1818

1919
include("debug_mode.jl")
2020

2121
include("tangent_types/abstract_tangent.jl")
22+
include("tangent_types/structural_tangent.jl")
2223
include("tangent_types/abstract_zero.jl")
2324
include("tangent_types/thunks.jl")
24-
include("tangent_types/tangent.jl")
2525
include("tangent_types/notimplemented.jl")
2626

2727
include("tangent_arithmetic.jl")

src/tangent_arithmetic.jl

+11-11
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ Base.:+(x::NotImplemented, ::NotImplemented) = x
2020
Base.:*(x::NotImplemented, ::NotImplemented) = x
2121
LinearAlgebra.dot(x::NotImplemented, ::NotImplemented) = x
2222
# `NotImplemented` always "wins" +
23-
for T in (:ZeroTangent, :NoTangent, :AbstractThunk, :Tangent, :Any)
23+
for T in (:ZeroTangent, :NoTangent, :AbstractThunk, :StructuralTangent, :Any)
2424
@eval Base.:+(x::NotImplemented, ::$T) = x
2525
@eval Base.:+(::$T, x::NotImplemented) = x
2626
end
@@ -33,7 +33,7 @@ for T in (:ZeroTangent, :NoTangent)
3333
@eval LinearAlgebra.dot(::$T, ::NotImplemented) = $T()
3434
end
3535
# `NotImplemented` "wins" * and dot for other types
36-
for T in (:AbstractThunk, :Tangent, :Any)
36+
for T in (:AbstractThunk, :StructuralTangent, :Any)
3737
@eval Base.:*(x::NotImplemented, ::$T) = x
3838
@eval Base.:*(::$T, x::NotImplemented) = x
3939
@eval LinearAlgebra.dot(x::NotImplemented, ::$T) = x
@@ -55,7 +55,7 @@ Base.:-(::NoTangent, ::NoTangent) = NoTangent()
5555
Base.:-(::NoTangent) = NoTangent()
5656
Base.:*(::NoTangent, ::NoTangent) = NoTangent()
5757
LinearAlgebra.dot(::NoTangent, ::NoTangent) = NoTangent()
58-
for T in (:AbstractThunk, :Tangent, :Any)
58+
for T in (:AbstractThunk, :StructuralTangent, :Any)
5959
@eval Base.:+(::NoTangent, b::$T) = b
6060
@eval Base.:+(a::$T, ::NoTangent) = a
6161
@eval Base.:-(::NoTangent, b::$T) = -b
@@ -95,7 +95,7 @@ Base.:-(::ZeroTangent, ::ZeroTangent) = ZeroTangent()
9595
Base.:-(::ZeroTangent) = ZeroTangent()
9696
Base.:*(::ZeroTangent, ::ZeroTangent) = ZeroTangent()
9797
LinearAlgebra.dot(::ZeroTangent, ::ZeroTangent) = ZeroTangent()
98-
for T in (:AbstractThunk, :Tangent, :Any)
98+
for T in (:AbstractThunk, :StructuralTangent, :Any)
9999
@eval Base.:+(::ZeroTangent, b::$T) = b
100100
@eval Base.:+(a::$T, ::ZeroTangent) = a
101101
@eval Base.:-(::ZeroTangent, b::$T) = -b
@@ -126,11 +126,11 @@ for T in (:Tangent, :Any)
126126
@eval Base.:*(a::$T, b::AbstractThunk) = a * unthunk(b)
127127
end
128128

129-
function Base.:+(a::Tangent{P}, b::Tangent{P}) where {P}
129+
function Base.:+(a::StructuralTangent{P}, b::StructuralTangent{P}) where {P}
130130
data = elementwise_add(backing(a), backing(b))
131-
return Tangent{P,typeof(data)}(data)
131+
return StructuralTangent{P}(data)
132132
end
133-
function Base.:+(a::P, d::Tangent{P}) where {P}
133+
function Base.:+(a::P, d::StructuralTangent{P}) where {P}
134134
net_backing = elementwise_add(backing(a), backing(d))
135135
if debug_mode()
136136
try
@@ -143,14 +143,14 @@ function Base.:+(a::P, d::Tangent{P}) where {P}
143143
end
144144
end
145145
Base.:+(a::Dict, d::Tangent{P}) where {P} = merge(+, a, backing(d))
146-
Base.:+(a::Tangent{P}, b::P) where {P} = b + a
146+
Base.:+(a::StructuralTangent{P}, b::P) where {P} = b + a
147147

148-
Base.:-(tangent::Tangent{P}) where {P} = map(-, tangent)
148+
Base.:-(tangent::StructuralTangent{P}) where {P} = map(-, tangent)
149149

150150
# We intentionally do not define, `Base.*(::Tangent, ::Tangent)` as that is not meaningful
151151
# In general one doesn't have to represent multiplications of 2 tangents
152152
# Only of a tangent and a scaling factor (generally `Real`)
153153
for T in (:Number,)
154-
@eval Base.:*(s::$T, tangent::Tangent) = map(x -> s * x, tangent)
155-
@eval Base.:*(tangent::Tangent, s::$T) = map(x -> x * s, tangent)
154+
@eval Base.:*(s::$T, tangent::StructuralTangent) = map(x -> s * x, tangent)
155+
@eval Base.:*(tangent::StructuralTangent, s::$T) = map(x -> x * s, tangent)
156156
end

src/tangent_types/abstract_zero.jl

+87
Original file line numberDiff line numberDiff line change
@@ -91,3 +91,90 @@ arguments.
9191
```
9292
"""
9393
struct NoTangent <: AbstractZero end
94+
95+
"""
96+
zero_tangent(primal)
97+
98+
This returns an appropriate zero tangent suitable for accumulating tangents of the primal.
99+
For mutable composites types this is a structural [`MutableTangent`](@ref)
100+
For `Array`s, it is applied recursively for each element.
101+
For other types, in particular immutable types, we do not make promises beyond that it will be `iszero`
102+
and suitable for accumulating against.
103+
For types without a tangent space (e.g. singleton structs) this returns `NoTangent()`.
104+
In general, it is more likely to produce a structural tangent.
105+
106+
!!! warning Exprimental
107+
`zero_tangent`is an experimental feature, and is part of the mutation support featureset.
108+
While this notice remains it may have changes in behavour, and interface in any _minor_ version of ChainRulesCore.
109+
Exactly how it should be used (e.g. is it forward-mode only?)
110+
"""
111+
function zero_tangent end
112+
113+
zero_tangent(x::Number) = zero(x)
114+
115+
zero_tangent(::Type) = NoTangent()
116+
117+
function zero_tangent(x::MutableTangent{P}) where {P}
118+
zb = backing(zero_tangent(backing(x)))
119+
return MutableTangent{P}(zb)
120+
end
121+
122+
function zero_tangent(x::Tangent{P}) where {P}
123+
zb = backing(zero_tangent(backing(x)))
124+
return Tangent{P,typeof(zb)}(zb)
125+
end
126+
127+
@generated function zero_tangent(primal)
128+
fieldcount(primal) == 0 && return NoTangent() # no tangent space at all, no need for structural zero.
129+
zfield_exprs = map(fieldnames(primal)) do fname
130+
fval = :(
131+
if isdefined(primal, $(QuoteNode(fname)))
132+
zero_tangent(getfield(primal, $(QuoteNode(fname))))
133+
else
134+
# This is going to be potentially bad, but that's what they get for not giving us a primal
135+
# This will never me mutated inplace, rather it will alway be replaced with an actual value first
136+
ZeroTangent()
137+
end
138+
)
139+
Expr(:kw, fname, fval)
140+
end
141+
return if has_mutable_tangent(primal)
142+
any_mask = map(fieldnames(primal), fieldtypes(primal)) do fname, ftype
143+
# If it is is unassigned, or if it doesn't have a concrete type, let it take any value for its tangent
144+
fdef = :(!isdefined(primal, $(QuoteNode(fname))) || !isconcretetype($ftype))
145+
Expr(:kw, fname, fdef)
146+
end
147+
:($MutableTangent{$primal}(
148+
$(Expr(:tuple, Expr(:parameters, any_mask...))),
149+
$(Expr(:tuple, Expr(:parameters, zfield_exprs...))),
150+
))
151+
else
152+
:($Tangent{$primal}($(Expr(:parameters, zfield_exprs...))))
153+
end
154+
end
155+
156+
zero_tangent(primal::Tuple) = Tangent{typeof(primal)}(map(zero_tangent, primal)...)
157+
158+
function zero_tangent(x::Array{P,N}) where {P,N}
159+
if (isbitstype(P) || all(i -> isassigned(x, i), eachindex(x)))
160+
return map(zero_tangent, x)
161+
end
162+
163+
# Now we need to handle nonfully assigned arrays
164+
# see discussion at https://github.com/JuliaDiff/ChainRulesCore.jl/pull/626#discussion_r1345235265
165+
y = Array{guess_zero_tangent_type(P),N}(undef, size(x)...)
166+
@inbounds for n in eachindex(y)
167+
if isassigned(x, n)
168+
y[n] = zero_tangent(x[n])
169+
end
170+
end
171+
return y
172+
end
173+
174+
# Sad heauristic methods we need because of unassigned values
175+
guess_zero_tangent_type(::Type{T}) where {T<:Number} = T
176+
guess_zero_tangent_type(::Type{T}) where {T<:Integer} = typeof(float(zero(T)))
177+
function guess_zero_tangent_type(::Type{<:Array{T,N}}) where {T,N}
178+
return Array{guess_zero_tangent_type(T),N}
179+
end
180+
guess_zero_tangent_type(T::Type) = Any

0 commit comments

Comments
 (0)