Skip to content

Commit f3dd452

Browse files
devmotionmcabbott
andauthored
Add rules for LogExpFunctions (#69)
* Add rules for LogExpFunctions * Add keyword argument to `diffrules` * Add tests * Add more doctests * Remove whitespace * More stable tests * Fix `log2mexp` tests * Update CI and cancel builds for old commits in PRs * Update src/api.jl Co-authored-by: Michael Abbott <[email protected]> * Change to `filter_modules` and update docs * Add preview of docs * Fix doctests Co-authored-by: Michael Abbott <[email protected]>
1 parent 2db3a81 commit f3dd452

File tree

8 files changed

+199
-69
lines changed

8 files changed

+199
-69
lines changed

.github/workflows/TagBot.yml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
name: TagBot
22
on:
3-
schedule:
4-
- cron: 0 * * * *
53
issue_comment:
64
types:
75
- created

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
*.jl.cov
22
*.jl.*.cov
33
*.jl.mem
4+
/Manifest.toml

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@ uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
33
version = "1.3.1"
44

55
[deps]
6+
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
67
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
78
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
89
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
910

1011
[compat]
12+
LogExpFunctions = "0.3"
1113
NaNMath = "0.3"
1214
SpecialFunctions = "0.8, 0.9, 0.10, 1.0"
1315
julia = "1"

docs/make.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,20 @@
11
using Documenter, DiffRules
22

3+
DocMeta.setdocmeta!(
4+
DiffRules,
5+
:DocTestSetup,
6+
:(using DiffRules);
7+
recursive=true,
8+
)
9+
310
makedocs(modules=[DiffRules],
4-
doctest = false,
511
sitename = "DiffRules",
612
pages = ["Documentation" => "index.md"],
713
format = Documenter.HTML(
814
prettyurls = get(ENV, "CI", nothing) == "true"
915
),
16+
strict=true,
17+
checkdocs=:exports,
1018
)
1119

1220
deploydocs(; repo="github.com/JuliaDiff/DiffRules.jl", push_preview=true)

src/DiffRules.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ __precompile__()
22

33
module DiffRules
44

5+
import LogExpFunctions
6+
57
include("api.jl")
68
include("rules.jl")
79

src/api.jl

Lines changed: 105 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,13 @@ interpolated wherever they are used on the RHS.
1616
1717
Note that differentiation rules are purely symbolic, so no type annotations should be used.
1818
19-
Examples:
20-
21-
@define_diffrule Base.cos(x) = :(-sin(\$x))
22-
@define_diffrule Base.:/(x, y) = :(inv(\$y)), :(-\$x / (\$y^2))
23-
@define_diffrule Base.polygamma(m, x) = :NaN, :(polygamma(\$m + 1, \$x))
19+
# Examples
2420
21+
```julia
22+
@define_diffrule Base.cos(x) = :(-sin(\$x))
23+
@define_diffrule Base.:/(x, y) = :(inv(\$y)), :(-\$x / (\$y^2))
24+
@define_diffrule Base.polygamma(m, x) = :NaN, :(polygamma(\$m + 1, \$x))
25+
```
2526
"""
2627
macro define_diffrule(def)
2728
@assert isa(def, Expr) && def.head == :(=) "Diff rule expression does not have a left and right side"
@@ -50,19 +51,18 @@ interpolated into the returned expression.
5051
In the `n`-ary case, an `n`-tuple of expressions will be returned where the `i`th expression
5152
is the derivative of `f` w.r.t the `i`th argument.
5253
53-
Examples:
54-
55-
julia> DiffRules.diffrule(:Base, :sin, 1)
56-
:(cos(1))
54+
# Examples
5755
58-
julia> DiffRules.diffrule(:Base, :sin, :x)
59-
:(cos(x))
56+
```jldoctest
57+
julia> DiffRules.diffrule(:Base, :sin, 1)
58+
:(cos(1))
6059
61-
julia> DiffRules.diffrule(:Base, :sin, :(x * y^2))
62-
:(cos(x * y ^ 2))
60+
julia> DiffRules.diffrule(:Base, :sin, :x)
61+
:(cos(x))
6362
64-
julia> DiffRules.diffrule(:Base, :^, :(x + 2), :c)
65-
(:(c * (x + 2) ^ (c - 1)), :((x + 2) ^ c * log(x + 2)))
63+
julia> DiffRules.diffrule(:Base, :sin, :(x * y^2))
64+
:(cos(x * y ^ 2))
65+
```
6666
"""
6767
diffrule(M::Union{Expr,Symbol}, f::Symbol, args...) = DEFINED_DIFFRULES[M,f,length(args)](args...)
6868

@@ -74,41 +74,109 @@ otherwise.
7474
7575
Here, `arity` refers to the number of arguments accepted by `f`.
7676
77-
Examples:
77+
# Examples
7878
79-
julia> DiffRules.hasdiffrule(:Base, :sin, 1)
80-
true
79+
```jldoctest
80+
julia> DiffRules.hasdiffrule(:Base, :sin, 1)
81+
true
8182
82-
julia> DiffRules.hasdiffrule(:Base, :sin, 2)
83-
false
83+
julia> DiffRules.hasdiffrule(:Base, :sin, 2)
84+
false
8485
85-
julia> DiffRules.hasdiffrule(:Base, :-, 1)
86-
true
86+
julia> DiffRules.hasdiffrule(:Base, :-, 1)
87+
true
8788
88-
julia> DiffRules.hasdiffrule(:Base, :-, 2)
89-
true
89+
julia> DiffRules.hasdiffrule(:Base, :-, 2)
90+
true
9091
91-
julia> DiffRules.hasdiffrule(:Base, :-, 3)
92-
false
92+
julia> DiffRules.hasdiffrule(:Base, :-, 3)
93+
false
94+
```
9395
"""
9496
hasdiffrule(M::Union{Expr,Symbol}, f::Symbol, arity::Int) = haskey(DEFINED_DIFFRULES, (M, f, arity))
9597

98+
# show a deprecation warning if `filter_modules` in `diffrules()` is specified implicitly
99+
# we use a custom singleton to figure out if the keyword argument was set explicitly
100+
struct DefaultFilterModules end
101+
102+
function deprecated_modules(modules)
103+
return if modules isa DefaultFilterModules
104+
Base.depwarn(
105+
"the implicit keyword argument " *
106+
"`filter_modules=(:Base, :SpecialFunctions, :NaNMath)` in `diffrules()` is " *
107+
"deprecated and will be changed to `filter_modules=nothing` in an upcoming " *
108+
"breaking release of DiffRules (i.e., `diffrules()` will return all rules " *
109+
"defined in DiffRules)",
110+
:diffrules,
111+
)
112+
(:Base, :SpecialFunctions, :NaNMath)
113+
else
114+
modules
115+
end
116+
end
117+
96118
"""
97-
diffrules()
119+
diffrules(; filter_modules=(:Base, :SpecialFunctions, :NaNMath))
98120
99-
Return a list of keys that can be used to access all defined differentiation rules.
121+
Return a list of keys that can be used to access all defined differentiation rules for
122+
modules in `filter_modules`.
100123
101124
Each key is of the form `(M::Symbol, f::Symbol, arity::Int)`.
102-
103-
Here, `arity` refers to the number of arguments accepted by `f`.
104-
105-
Examples:
106-
107-
julia> first(DiffRules.diffrules())
108-
(:Base, :asind, 1)
109-
125+
Here, `arity` refers to the number of arguments accepted by `f` and `M` is one of the
126+
modules in `filter_modules`.
127+
128+
To include all rules, specify `filter_modules = nothing`.
129+
130+
!!! note
131+
Calling `diffrules()` with the implicit default keyword argument `filter_modules`
132+
does *not* return all rules defined by this package but rather only rules for the
133+
packages for which DiffRules 1.0 provided rules. This is done in order to not to
134+
break downstream packages that assumed this list would never change.
135+
It is planned to change `diffrules()` to return all rules, i.e., to use the
136+
default keyword argument `filter_modules=nothing`, in an upcoming breaking release
137+
of DiffRules.
138+
139+
# Examples
140+
141+
```jldoctest
142+
julia> first(DiffRules.diffrules())
143+
(:Base, :log2, 1)
144+
```
145+
146+
If you call `diffrules()`, only rules for Base, SpecialFunctions, and
147+
NaNMath are returned but no rules for LogExpFunctions:
148+
```jldoctest
149+
julia> any(M === :LogExpFunctions for (M, _, _) in DiffRules.diffrules())
150+
false
151+
```
152+
153+
If you set `filter_modules=nothing`, all rules defined in DiffRules are
154+
returned and in particular also rules for LogExpFunctions:
155+
```jldoctest
156+
julia> any(
157+
M === :LogExpFunctions
158+
for (M, _, _) in DiffRules.diffrules(; filter_modules=nothing)
159+
)
160+
true
161+
```
162+
163+
If you set `filter_modules=(:Base,)` only rules for functions in Base are
164+
returned:
165+
```jldoctest
166+
julia> all(M === :Base for (M, _, _) in DiffRules.diffrules(; filter_modules=(:Base,)))
167+
true
168+
```
110169
"""
111-
diffrules() = keys(DEFINED_DIFFRULES)
170+
function diffrules(; filter_modules=DefaultFilterModules())
171+
modules = deprecated_modules(filter_modules)
172+
return if modules === nothing
173+
keys(DEFINED_DIFFRULES)
174+
else
175+
Iterators.filter(keys(DEFINED_DIFFRULES)) do (M, _, _)
176+
return M in modules
177+
end
178+
end
179+
end
112180

113181
# For v0.6 and v0.7 compatibility, need to support having the diff rule function enter as a
114182
# `Expr(:quote...)` and a `QuoteNode`. When v0.6 support is dropped, the function will

src/rules.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,3 +232,30 @@ end
232232
:(ifelse(($y > $x) | (signbit($y) < signbit($x)), ifelse(isnan($y), zero($y), one($y)), ifelse(isnan($x), one($y), zero($y))))
233233
@define_diffrule NaNMath.min(x, y) = :(ifelse(($y < $x) | (signbit($y) > signbit($x)), ifelse(isnan($y), one($x), zero($x)), ifelse(isnan($x), zero($x), one($x)))),
234234
:(ifelse(($y < $x) | (signbit($y) > signbit($x)), ifelse(isnan($y), zero($y), one($y)), ifelse(isnan($x), one($x), zero($x))))
235+
236+
###################
237+
# LogExpFunctions #
238+
###################
239+
240+
# unary
241+
@define_diffrule LogExpFunctions.xlogx(x) = :(1 + log($x))
242+
@define_diffrule LogExpFunctions.logistic(x) = :(z = LogExpFunctions.logistic($x); z * (1 - z))
243+
@define_diffrule LogExpFunctions.logit(x) = :(inv($x * (1 - $x)))
244+
@define_diffrule LogExpFunctions.log1psq(x) = :(2 * $x / (1 + $x^2))
245+
@define_diffrule LogExpFunctions.log1pexp(x) = :(LogExpFunctions.logistic($x))
246+
@define_diffrule LogExpFunctions.log1mexp(x) = :(-exp($x - LogExpFunctions.log1mexp($x)))
247+
@define_diffrule LogExpFunctions.log2mexp(x) = :(-exp($x - LogExpFunctions.log2mexp($x)))
248+
@define_diffrule LogExpFunctions.logexpm1(x) = :(exp($x - LogExpFunctions.logexpm1($x)))
249+
250+
# binary
251+
@define_diffrule LogExpFunctions.xlogy(x, y) = :(log($y)), :($x / $y)
252+
@define_diffrule LogExpFunctions.logaddexp(x, y) =
253+
:(exp($x - LogExpFunctions.logaddexp($x, $y))), :(exp($y - LogExpFunctions.logaddexp($x, $y)))
254+
@define_diffrule LogExpFunctions.logsubexp(x, y) =
255+
:(z = LogExpFunctions.logsubexp($x, $y); $x > $y ? exp($x - z) : -exp($x - z)),
256+
:(z = LogExpFunctions.logsubexp($x, $y); $x > $y ? -exp($y - z) : exp($y - z))
257+
258+
# only defined in LogExpFunctions >= 0.3.2
259+
if isdefined(LogExpFunctions, :xlog1py)
260+
@define_diffrule LogExpFunctions.xlog1py(x, y) = :(log1p($y)), :($x / (1 + $y))
261+
end

test/runtests.jl

Lines changed: 53 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,49 +1,58 @@
1-
if VERSION < v"0.7-"
2-
using Base.Test
3-
srand(1)
4-
else
5-
using Test
6-
import Random
7-
Random.seed!(1)
8-
end
9-
import SpecialFunctions, NaNMath
101
using DiffRules
2+
using Test
113

4+
import SpecialFunctions, NaNMath, LogExpFunctions
5+
import Random
6+
Random.seed!(1)
127

138
function finitediff(f, x)
149
ϵ = cbrt(eps(typeof(x))) * max(one(typeof(x)), abs(x))
1510
return (f(x + ϵ) - f(x - ϵ)) /+ ϵ)
1611
end
1712

13+
@testset "DiffRules" begin
14+
@testset "check rules" begin
1815

1916
non_numeric_arg_functions = [(:Base, :rem2pi, 2), (:Base, :ifelse, 3)]
2017

21-
for (M, f, arity) in DiffRules.diffrules()
18+
for (M, f, arity) in DiffRules.diffrules(; filter_modules=nothing)
2219
(M, f, arity) non_numeric_arg_functions && continue
2320
if arity == 1
2421
@test DiffRules.hasdiffrule(M, f, 1)
2522
deriv = DiffRules.diffrule(M, f, :goo)
26-
modifier = in(f, (:asec, :acsc, :asecd, :acscd, :acosh, :acoth)) ? 1 : 0
23+
modifier = if f in (:asec, :acsc, :asecd, :acscd, :acosh, :acoth)
24+
1.0
25+
elseif f === :log1mexp
26+
-1.0
27+
elseif f === :log2mexp
28+
-0.5
29+
else
30+
0.0
31+
end
2732
@eval begin
28-
goo = rand() + $modifier
29-
@test isapprox($deriv, finitediff($M.$f, goo), rtol=0.05)
30-
# test for 2pi functions
31-
if "mod2pi" == string($M.$f)
32-
goo = 4pi + $modifier
33-
@test NaN === $deriv
33+
let
34+
goo = rand() + $modifier
35+
@test isapprox($deriv, finitediff($M.$f, goo), rtol=0.05)
36+
# test for 2pi functions
37+
if "mod2pi" == string($M.$f)
38+
goo = 4pi + $modifier
39+
@test NaN === $deriv
40+
end
3441
end
3542
end
3643
elseif arity == 2
3744
@test DiffRules.hasdiffrule(M, f, 2)
3845
derivs = DiffRules.diffrule(M, f, :foo, :bar)
3946
@eval begin
40-
foo, bar = rand(1:10), rand()
41-
dx, dy = $(derivs[1]), $(derivs[2])
42-
if !(isnan(dx))
43-
@test isapprox(dx, finitediff(z -> $M.$f(z, bar), float(foo)), rtol=0.05)
44-
end
45-
if !(isnan(dy))
46-
@test isapprox(dy, finitediff(z -> $M.$f(foo, z), bar), rtol=0.05)
47+
let
48+
foo, bar = rand(1:10), rand()
49+
dx, dy = $(derivs[1]), $(derivs[2])
50+
if !(isnan(dx))
51+
@test isapprox(dx, finitediff(z -> $M.$f(z, bar), float(foo)), rtol=0.05)
52+
end
53+
if !(isnan(dy))
54+
@test isapprox(dy, finitediff(z -> $M.$f(foo, z), bar), rtol=0.05)
55+
end
4756
end
4857
end
4958
elseif arity == 3
@@ -72,14 +81,29 @@ derivs = DiffRules.diffrule(:Base, :rem2pi, :x, :y)
7281
for xtype in [:Float64, :BigFloat, :Int64]
7382
for mode in [:RoundUp, :RoundDown, :RoundToZero, :RoundNearest]
7483
@eval begin
75-
x = $xtype(rand(1 : 10))
76-
y = $mode
77-
dx, dy = $(derivs[1]), $(derivs[2])
78-
@test isapprox(dx, finitediff(z -> rem2pi(z, y), float(x)), rtol=0.05)
79-
@test isnan(dy)
84+
let
85+
x = $xtype(rand(1 : 10))
86+
y = $mode
87+
dx, dy = $(derivs[1]), $(derivs[2])
88+
@test isapprox(dx, finitediff(z -> rem2pi(z, y), float(x)), rtol=0.05)
89+
@test isnan(dy)
90+
end
8091
end
8192
end
8293
end
94+
end
95+
96+
@testset "diffrules" begin
97+
rules = @test_deprecated(DiffRules.diffrules())
98+
@test Set(M for (M, _, _) in rules) == Set((:Base, :SpecialFunctions, :NaNMath))
99+
100+
rules = DiffRules.diffrules(; filter_modules=nothing)
101+
@test Set(M for (M, _, _) in rules) == Set((:Base, :SpecialFunctions, :NaNMath, :LogExpFunctions))
102+
103+
rules = DiffRules.diffrules(; filter_modules=(:Base, :LogExpFunctions))
104+
@test Set(M for (M, _, _) in rules) == Set((:Base, :LogExpFunctions))
105+
end
106+
end
83107

84108
# Test ifelse separately as first argument is boolean
85109
#=

0 commit comments

Comments
 (0)