Skip to content

Commit 789f498

Browse files
yebaigithub-actions[bot]mhaurupenelopeysm
authored
Remove Zygote (#2505)
* Remove `Zygote`; fix #2504 * Update test/test_utils/ad_utils.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Add HISTORY.md entry about removing support for Zygote --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Markus Hauru <[email protected]> Co-authored-by: Penelope Yong <[email protected]> Co-authored-by: Markus Hauru <[email protected]>
1 parent f904f99 commit 789f498

File tree

7 files changed

+8
-33
lines changed

7 files changed

+8
-33
lines changed

HISTORY.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,12 @@
66

77
0.37 removes the old Gibbs constructors deprecated in 0.36.
88

9+
### Remove Zygote support
10+
11+
Zygote is no longer officially supported as an automatic differentiation backend, and `AutoZygote` is no longer exported. You can continue to use Zygote by importing `AutoZygote` from ADTypes and it may well continue to work, but it is no longer tested and no effort will be expended to fix it if something breaks.
12+
13+
[Mooncake](https://github.com/compintell/Mooncake.jl/) is the recommended replacement for Zygote.
14+
915
### DynamicPPL 0.35
1016

1117
Turing.jl v0.37 uses DynamicPPL v0.35, which brings with it several breaking changes:

docs/src/api.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,6 @@ See the [AD guide](https://turinglang.org/docs/tutorials/docs-10-using-turing-au
8888
|:----------------- |:------------------------------------ |:---------------------- |
8989
| `AutoForwardDiff` | [`ADTypes.AutoForwardDiff`](@extref) | ForwardDiff.jl backend |
9090
| `AutoReverseDiff` | [`ADTypes.AutoReverseDiff`](@extref) | ReverseDiff.jl backend |
91-
| `AutoZygote` | [`ADTypes.AutoZygote`](@extref) | Zygote.jl backend |
9291
| `AutoMooncake` | [`ADTypes.AutoMooncake`](@extref) | Mooncake.jl backend |
9392

9493
### Debugging

src/Turing.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,6 @@ export @model, # modelling
106106
externalsampler,
107107
AutoForwardDiff, # ADTypes
108108
AutoReverseDiff,
109-
AutoZygote,
110109
AutoMooncake,
111110
setprogress!, # debugging
112111
Flat,

src/essential/Essential.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ using Bijectors: PDMatDistribution
1111
using AdvancedVI
1212
using StatsFuns: logsumexp, softmax
1313
@reexport using DynamicPPL
14-
using ADTypes: ADTypes, AutoForwardDiff, AutoReverseDiff, AutoZygote, AutoMooncake
14+
using ADTypes: ADTypes, AutoForwardDiff, AutoReverseDiff, AutoMooncake
1515

1616
using AdvancedPS: AdvancedPS
1717

@@ -20,7 +20,6 @@ include("container.jl")
2020
export @model,
2121
@varname,
2222
AutoForwardDiff,
23-
AutoZygote,
2423
AutoReverseDiff,
2524
AutoMooncake,
2625
@logprob_str,

test/Project.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
3636
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
3737
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3838
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
39-
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
4039

4140
[compat]
4241
AbstractMCMC = "5"
@@ -75,5 +74,4 @@ StableRNGs = "1"
7574
StatsBase = "0.33, 0.34"
7675
StatsFuns = "0.9.5, 1"
7776
TimerOutputs = "0.5"
78-
Zygote = "0.5.4, 0.6"
7977
julia = "1.10"

test/test_utils/ad_utils.jl

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ using Mooncake: Mooncake
88
using Test: Test
99
using Turing: Turing
1010
using Turing: DynamicPPL
11-
using Zygote: Zygote
1211

1312
export ADTypeCheckContext, adbackends
1413

@@ -31,9 +30,6 @@ const eltypes_by_adtype = Dict(
3130
ReverseDiff.TrackedVector,
3231
),
3332
Turing.AutoMooncake => (Mooncake.CoDual,),
34-
# Zygote.Dual is actually the same as ForwardDiff.Dual, so can't distinguish between the
35-
# two by element type. However, we have other checks for Zygote, see check_adtype.
36-
Turing.AutoZygote => (Zygote.Dual,),
3733
)
3834

3935
"""
@@ -90,7 +86,6 @@ For instance, evaluating a model with
9086
would throw an error if within the model a type associated with e.g. ReverseDiff was
9187
encountered.
9288
93-
As a current short-coming, this context can not distinguish between ForwardDiff and Zygote.
9489
"""
9590
struct ADTypeCheckContext{ADType,ChildContext<:DynamicPPL.AbstractContext} <:
9691
DynamicPPL.AbstractContext
@@ -134,21 +129,9 @@ end
134129
135130
Check that the element types in `vi` are compatible with the ADType of `context`.
136131
137-
When Zygote is being used, we also more explicitly check that `adtype(context)` is
138-
`AutoZygote`. This is because Zygote uses the same element type as ForwardDiff, so we can't
139-
discriminate between the two based on element type alone. This function will still fail to
140-
catch cases where Zygote is supposed to be used, but ForwardDiff is used instead.
141-
142-
Throw an `IncompatibleADTypeError` if an incompatible element type is encountered, or
143-
`WrongADBackendError` if Zygote is used unexpectedly.
132+
Throw an `IncompatibleADTypeError` if an incompatible element type is encountered.
144133
"""
145134
function check_adtype(context::ADTypeCheckContext, vi::DynamicPPL.AbstractVarInfo)
146-
Zygote.hook(vi) do _
147-
if !(adtype(context) <: Turing.AutoZygote)
148-
throw(WrongADBackendError(Turing.AutoZygote, adtype(context)))
149-
end
150-
end
151-
152135
valids = valid_eltypes(context)
153136
for val in vi[:]
154137
valtype = typeof(val)

test/test_utils/test_utils.jl

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ using ReverseDiff: ReverseDiff
77
using Test: @test, @testset, @test_throws
88
using Turing: Turing
99
using Turing: DynamicPPL
10-
using Zygote: Zygote
1110

1211
# Check that the ADTypeCheckContext works as expected.
1312
@testset "ADTypeCheckContext" begin
@@ -16,20 +15,12 @@ using Zygote: Zygote
1615
adtypes = (
1716
Turing.AutoForwardDiff(),
1817
Turing.AutoReverseDiff(),
19-
Turing.AutoZygote(),
2018
# TODO: Mooncake
2119
# Turing.AutoMooncake(config=nothing),
2220
)
2321
for actual_adtype in adtypes
2422
sampler = Turing.HMC(0.1, 5; adtype=actual_adtype)
2523
for expected_adtype in adtypes
26-
if (
27-
actual_adtype == Turing.AutoForwardDiff() &&
28-
expected_adtype == Turing.AutoZygote()
29-
)
30-
# TODO(mhauru) We are currently unable to check this case.
31-
continue
32-
end
3324
contextualised_tm = DynamicPPL.contextualize(
3425
tm, ADTypeCheckContext(expected_adtype, tm.context)
3526
)

0 commit comments

Comments
 (0)