Skip to content

Fix tests #62

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 0 additions & 10 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,3 @@ StaticArrays = "1"
StatsBase = "0.33"
StructArrays = "0.6"
julia = "1.7"

[extras]
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "ForwardDiff", "LinearAlgebra", "Random", "Symbolics"]
22 changes: 22 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
[deps]
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
ChainRules = "1.5"
ChainRulesCore = "1.2"
Combinatorics = "1"
StaticArrays = "1"
StatsBase = "0.33"
StructArrays = "0.6"
julia = "1.7"
32 changes: 18 additions & 14 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,21 +25,21 @@ ChainRules.rrule(::typeof(my_tuple), args...) = args, Δ->Core.tuple(NoTangent()
# Check characteristic of exp rule
@variables ω α β γ δ ϵ ζ η
(x1, c1) = ∂⃖{3}()(exp, ω)
@test simplify(x1 == exp(ω)).val
@test isequal(simplify(x1), simplify(exp(ω)))
((_, x2), c2) = c1(α)
@test simplify(x2 == α*exp(ω)).val
@test isequal(simplify(x2), simplify(α*exp(ω)))
(x3, c3) = c2(ZeroTangent(), β)
@test simplify(x3 == β*exp(ω)).val
@test isequal(simplify(x3), simplify(β*exp(ω)))
((_, x4), c4) = c3(γ)
@test simplify(x4 == exp(ω)*(γ + (α*β))).val
@test isequal(simplify(x4), simplify(exp(ω)*(γ + (α*β))))
(x5, c5) = c4(ZeroTangent(), δ)
@test simplify(x5 == δ*exp(ω)).val
@test isequal(simplify(x5), simplify(δ*exp(ω)))
((_, x6), c6) = c5(ϵ)
@test simplify(x6 == ϵ*exp(ω) + α*δ*exp(ω)).val
@test isequal(simplify(x6), simplify(ϵ*exp(ω) + α*δ*exp(ω)))
(x7, c7) = c6(ZeroTangent(), ζ)
@test simplify(x7 == ζ*exp(ω) + β*δ*exp(ω)).val
@test isequal(simplify(x7), simplify(ζ*exp(ω) + β*δ*exp(ω)))
(_, x8) = c7(η)
@test simplify(x8 == (η + (α*ζ) + (β*ϵ) + (δ*(γ + (α*β))))*exp(ω)).val
@test isequal(simplify(x8), simplify((η + (α*ζ) + (β*ϵ) + (δ*(γ + (α*β))))*exp(ω)))

# Minimal 2-nd order forward smoke test
@test Diffractor.∂☆{2}()(Diffractor.ZeroBundle{2}(sin),
Expand Down Expand Up @@ -123,10 +123,12 @@ let var"'" = Diffractor.PrimeDerivativeFwd
# Integration tests
@test recursive_sin'(1.0) == cos(1.0)
@test recursive_sin''(1.0) == -sin(1.0)
@test recursive_sin'''(1.0) == -cos(1.0)
@test recursive_sin''''(1.0) == sin(1.0)
@test recursive_sin'''''(1.0) == cos(1.0)
@test recursive_sin''''''(1.0) == -sin(1.0)
# Error: ArgumentError: Tangent for the primal Tangent{Tuple{Float64, Float64}, Tuple{Float64, Float64}}
# should be backed by a NamedTuple type, not by Tuple{Tangent{Tuple{Float64, Float64}, Tuple{Float64, Float64}}}.
Comment on lines +126 to +127
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

JuliaDiff/ChainRulesCore.jl#503 should address this, but the integration test still seems to be failing. Probably fine mark these as broken for now, until I figure that out.

@test_broken recursive_sin'''(1.0) == -cos(1.0)
@test_broken recursive_sin''''(1.0) == sin(1.0)
@test_broken recursive_sin'''''(1.0) == cos(1.0)
@test_broken recursive_sin''''''(1.0) == -sin(1.0)

# Test the special rules for sin/cos/exp
@test sin''''''(1.0) == -sin(1.0)
Expand All @@ -148,7 +150,7 @@ end
@test gradient(x -> sum(abs2, x .+ 1.0), zeros(3))[1] == [2.0, 2.0, 2.0]

const fwd = Diffractor.PrimeDerivativeFwd
const bwd = Diffractor.PrimeDerivativeFwd
const bwd = Diffractor.PrimeDerivativeBack
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

whoops.


function f_broadcast(a)
l = a / 2.0 * [[0. 1. 1.]; [1. 0. 1.]; [1. 1. 0.]]
Expand Down Expand Up @@ -186,7 +188,9 @@ end
# Issue #27 - Mixup in lifting of getfield
let var"'" = bwd
@test (x->x^5)''(1.0) == 20.
@test (x->x^5)'''(1.0) == 60.
@test (x->(x*x)*(x*x)*x)''' == 60.
# Higher order control flow not yet supported (https://github.com/JuliaDiff/Diffractor.jl/issues/24)
@test_broken (x->x^5)'''(1.0) == 60.
end

# Issue #38 - Splatting arrays
Expand Down