Skip to content

Commit 0163b35

Browse files
authored
Remove Zygote; fix #2504
1 parent 1397d69 commit 0163b35

File tree

7 files changed

+3
-51
lines changed

7 files changed

+3
-51
lines changed

docs/src/api.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,6 @@ See the [AD guide](https://turinglang.org/docs/tutorials/docs-10-using-turing-au
8888
|:----------------- |:------------------------------------ |:---------------------- |
8989
| `AutoForwardDiff` | [`ADTypes.AutoForwardDiff`](@extref) | ForwardDiff.jl backend |
9090
| `AutoReverseDiff` | [`ADTypes.AutoReverseDiff`](@extref) | ReverseDiff.jl backend |
91-
| `AutoZygote` | [`ADTypes.AutoZygote`](@extref) | Zygote.jl backend |
9291
| `AutoMooncake` | [`ADTypes.AutoMooncake`](@extref) | Mooncake.jl backend |
9392

9493
### Debugging

src/Turing.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,6 @@ export @model, # modelling
106106
externalsampler,
107107
AutoForwardDiff, # ADTypes
108108
AutoReverseDiff,
109-
AutoZygote,
110109
AutoMooncake,
111110
setprogress!, # debugging
112111
Flat,

src/essential/Essential.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ using Bijectors: PDMatDistribution
1111
using AdvancedVI
1212
using StatsFuns: logsumexp, softmax
1313
@reexport using DynamicPPL
14-
using ADTypes: ADTypes, AutoForwardDiff, AutoReverseDiff, AutoZygote, AutoMooncake
14+
using ADTypes: ADTypes, AutoForwardDiff, AutoReverseDiff, AutoMooncake
1515

1616
using AdvancedPS: AdvancedPS
1717

@@ -20,7 +20,6 @@ include("container.jl")
2020
export @model,
2121
@varname,
2222
AutoForwardDiff,
23-
AutoZygote,
2423
AutoReverseDiff,
2524
AutoMooncake,
2625
@logprob_str,

test/Project.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
3636
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
3737
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3838
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
39-
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
4039

4140
[compat]
4241
AbstractMCMC = "5"
@@ -75,5 +74,4 @@ StableRNGs = "1"
7574
StatsBase = "0.33, 0.34"
7675
StatsFuns = "0.9.5, 1"
7776
TimerOutputs = "0.5"
78-
Zygote = "0.5.4, 0.6"
7977
julia = "1.10"

test/essential/ad.jl

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ using ReverseDiff
1111
using Test: @test, @testset
1212
using Turing
1313
using Turing: SampleFromPrior
14-
using Zygote
1514

1615
function test_model_ad(model, f, syms::Vector{Symbol})
1716
# Set up VI.
@@ -87,20 +86,6 @@ end
8786
vi, ad_test_f, SampleFromPrior(), DynamicPPL.DefaultContext()
8887
)
8988
x = map(x -> Float64(x), vi[SampleFromPrior()])
90-
91-
zygoteℓ = LogDensityProblemsAD.ADgradient(Turing.AutoZygote(), ℓ)
92-
if isdefined(Base, :get_extension)
93-
@test zygoteℓ isa
94-
Base.get_extension(
95-
LogDensityProblemsAD, :LogDensityProblemsADZygoteExt
96-
).ZygoteGradientLogDensity
97-
else
98-
@test zygoteℓ isa
99-
LogDensityProblemsAD.LogDensityProblemsADZygoteExt.ZygoteGradientLogDensity
100-
end
101-
@test zygoteℓ.===
102-
∇E2 = LogDensityProblems.logdensity_and_gradient(zygoteℓ, x)[2]
103-
@test sort(∇E2) grad_FWAD atol = 1e-9
10489
end
10590

10691
@testset "general AD tests" begin
@@ -135,11 +120,10 @@ end
135120

136121
test_model_ad(wishart_ad(), logp3, [:v])
137122
end
138-
@testset "Simplex Zygote and ReverseDiff (with and without caching) AD" begin
123+
@testset "Simplex ReverseDiff (with and without caching) AD" begin
139124
@model function dir()
140125
return theta ~ Dirichlet(1 ./ fill(4, 4))
141126
end
142-
sample(dir(), HMC(0.01, 1; adtype=AutoZygote()), 1000)
143127
sample(dir(), HMC(0.01, 1; adtype=AutoReverseDiff(; compile=false)), 1000)
144128
sample(dir(), HMC(0.01, 1; adtype=AutoReverseDiff(; compile=true)), 1000)
145129
end
@@ -149,14 +133,12 @@ end
149133
end
150134

151135
sample(wishart(), HMC(0.01, 1; adtype=AutoReverseDiff(; compile=false)), 1000)
152-
sample(wishart(), HMC(0.01, 1; adtype=AutoZygote()), 1000)
153136

154137
@model function invwishart()
155138
return theta ~ InverseWishart(4, Matrix{Float64}(I, 4, 4))
156139
end
157140

158141
sample(invwishart(), HMC(0.01, 1; adtype=AutoReverseDiff(; compile=false)), 1000)
159-
sample(invwishart(), HMC(0.01, 1; adtype=AutoZygote()), 1000)
160142
end
161143
@testset "Hessian test" begin
162144
@model function tst(x, ::Type{TV}=Vector{Float64}) where {TV}

test/test_utils/ad_utils.jl

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ using Mooncake: Mooncake
88
using Test: Test
99
using Turing: Turing
1010
using Turing: DynamicPPL
11-
using Zygote: Zygote
1211

1312
export ADTypeCheckContext, adbackends
1413

@@ -31,9 +30,6 @@ const eltypes_by_adtype = Dict(
3130
ReverseDiff.TrackedVector,
3231
),
3332
Turing.AutoMooncake => (Mooncake.CoDual,),
34-
# Zygote.Dual is actually the same as ForwardDiff.Dual, so can't distinguish between the
35-
# two by element type. However, we have other checks for Zygote, see check_adtype.
36-
Turing.AutoZygote => (Zygote.Dual,),
3733
)
3834

3935
"""
@@ -90,7 +86,6 @@ For instance, evaluating a model with
9086
would throw an error if within the model a type associated with e.g. ReverseDiff was
9187
encountered.
9288
93-
As a current short-coming, this context can not distinguish between ForwardDiff and Zygote.
9489
"""
9590
struct ADTypeCheckContext{ADType,ChildContext<:DynamicPPL.AbstractContext} <:
9691
DynamicPPL.AbstractContext
@@ -134,20 +129,9 @@ end
134129
135130
Check that the element types in `vi` are compatible with the ADType of `context`.
136131
137-
When Zygote is being used, we also more explicitly check that `adtype(context)` is
138-
`AutoZygote`. This is because Zygote uses the same element type as ForwardDiff, so we can't
139-
discriminate between the two based on element type alone. This function will still fail to
140-
catch cases where Zygote is supposed to be used, but ForwardDiff is used instead.
141-
142-
Throw an `IncompatibleADTypeError` if an incompatible element type is encountered, or
143-
`WrongADBackendError` if Zygote is used unexpectedly.
132+
Throw an `IncompatibleADTypeError` if an incompatible element type is encountered.
144133
"""
145134
function check_adtype(context::ADTypeCheckContext, vi::DynamicPPL.AbstractVarInfo)
146-
Zygote.hook(vi) do _
147-
if !(adtype(context) <: Turing.AutoZygote)
148-
throw(WrongADBackendError(Turing.AutoZygote, adtype(context)))
149-
end
150-
end
151135

152136
valids = valid_eltypes(context)
153137
for val in vi[:]

test/test_utils/test_utils.jl

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ using ReverseDiff: ReverseDiff
77
using Test: @test, @testset, @test_throws
88
using Turing: Turing
99
using Turing: DynamicPPL
10-
using Zygote: Zygote
1110

1211
# Check that the ADTypeCheckContext works as expected.
1312
@testset "ADTypeCheckContext" begin
@@ -16,20 +15,12 @@ using Zygote: Zygote
1615
adtypes = (
1716
Turing.AutoForwardDiff(),
1817
Turing.AutoReverseDiff(),
19-
Turing.AutoZygote(),
2018
# TODO: Mooncake
2119
# Turing.AutoMooncake(config=nothing),
2220
)
2321
for actual_adtype in adtypes
2422
sampler = Turing.HMC(0.1, 5; adtype=actual_adtype)
2523
for expected_adtype in adtypes
26-
if (
27-
actual_adtype == Turing.AutoForwardDiff() &&
28-
expected_adtype == Turing.AutoZygote()
29-
)
30-
# TODO(mhauru) We are currently unable to check this case.
31-
continue
32-
end
3324
contextualised_tm = DynamicPPL.contextualize(
3425
tm, ADTypeCheckContext(expected_adtype, tm.context)
3526
)

0 commit comments

Comments
 (0)