Skip to content

Commit 47214f8

Browse files
authored
Remove adjoint for fill and fix tests (#203)
* Remove adjoint for `fill` and fix Zygote tests * Bump version * Fix some more problems * Extension of #203: Fix deprecations in test (#204) * Fix deprecations * Improve CI (AD): cancel builds and no coverage * Improve CI (Others): cancel builds and no coverage * Change parameters to avoid issues with `xlogy` * Tracker does not like Diagonal(Fill(...)) * Unify CI * Fix tests * Update test structure and separate AD better * Fix tests * Relax type constraint * Simplify Zygote tests and use CR * Improve test design * Fix typo * Fix typo * Replace `unpack` with `_to_vec` * Fix tests (a bit) * Fix another test problem * Fix `_to_vec` * Fix handling of broken Zygote tests * Workarounds for `rand_tangent` * Improvements and fixes for Julia 1.3 * Remove Zygote test hack
1 parent af2ea73 commit 47214f8

File tree

12 files changed

+617
-483
lines changed

12 files changed

+617
-483
lines changed
Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,20 @@
1-
name: AD tests
1+
name: CI
22

33
on:
44
push:
55
branches:
66
- master
77
pull_request:
88

9+
concurrency:
10+
# Skip intermediate builds: always.
11+
# Cancel intermediate builds: only if it is a pull request build.
12+
group: ${{ github.workflow }}-${{ github.ref }}
13+
cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }}
14+
915
jobs:
1016
test:
1117
runs-on: ${{ matrix.os }}
12-
continue-on-error: ${{ matrix.version == 'nightly' }}
1318
strategy:
1419
matrix:
1520
version:
@@ -19,7 +24,8 @@ jobs:
1924
- ubuntu-latest
2025
arch:
2126
- x64
22-
AD:
27+
group:
28+
- Others
2329
- ForwardDiff
2430
- Tracker
2531
- ReverseDiff
@@ -28,27 +34,42 @@ jobs:
2834
- version: '1'
2935
os: macOS-latest
3036
arch: x64
31-
AD: ForwardDiff
37+
group: Others
38+
- version: '1'
39+
os: macOS-latest
40+
arch: x64
41+
group: ForwardDiff
3242
- version: '1'
3343
os: macOS-latest
3444
arch: x64
35-
AD: Tracker
45+
group: Tracker
3646
- version: '1'
3747
os: macOS-latest
3848
arch: x64
39-
AD: ReverseDiff
49+
group: ReverseDiff
4050
- version: '1'
4151
os: macOS-latest
4252
arch: x64
43-
AD: Zygote
53+
group: Zygote
4454
steps:
4555
- uses: actions/checkout@v2
4656
- uses: julia-actions/setup-julia@v1
4757
with:
4858
version: ${{ matrix.version }}
4959
arch: ${{ matrix.arch }}
60+
- uses: actions/cache@v1
61+
env:
62+
cache-name: cache-artifacts
63+
with:
64+
path: ~/.julia/artifacts
65+
key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }}
66+
restore-keys: |
67+
${{ runner.os }}-test-${{ env.cache-name }}-
68+
${{ runner.os }}-test-
69+
${{ runner.os }}-
5070
- uses: julia-actions/julia-buildpkg@latest
5171
- uses: julia-actions/julia-runtest@latest
72+
with:
73+
coverage: false
5274
env:
53-
GROUP: AD
54-
AD: ${{ matrix.AD }}
75+
GROUP: ${{ matrix.group }}

.github/workflows/Others.yml

Lines changed: 0 additions & 34 deletions
This file was deleted.

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DistributionsAD"
22
uuid = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
3-
version = "0.6.31"
3+
version = "0.6.32"
44

55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

README.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# DistributionsAD.jl
22

3-
[![AD tests](https://github.com/TuringLang/DistributionsAD.jl/actions/workflows/AD.yml/badge.svg?branch=master)](https://github.com/TuringLang/DistributionsAD.jl/actions/workflows/AD.yml?query=branch%3Amaster)
4-
[![Other tests](https://github.com/TuringLang/DistributionsAD.jl/actions/workflows/Others.yml/badge.svg?branch=master)](https://github.com/TuringLang/DistributionsAD.jl/actions/workflows/Others.yml?query=branch%3Amaster)
3+
[![CI](https://github.com/TuringLang/DistributionsAD.jl/actions/workflows/CI.yml/badge.svg?branch=master)](https://github.com/TuringLang/DistributionsAD.jl/actions/workflows/CI.yml?query=branch%3Amaster)
54

65
This package defines the necessary functions to enable automatic differentiation (AD) of the `logpdf` function from [Distributions.jl](https://github.com/JuliaStats/Distributions.jl) using the packages [Tracker.jl](https://github.com/FluxML/Tracker.jl), [Zygote.jl](https://github.com/FluxML/Zygote.jl), [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) and [ReverseDiff.jl](https://github.com/JuliaDiff/ReverseDiff.jl). The goal of this package is to make the output of `logpdf` differentiable wrt all continuous parameters of a distribution as well as the random variable in the case of continuous distributions.
76

src/flatten.jl

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ const flattened_dists = [ Bernoulli,
2727
Poisson,
2828
Skellam,
2929
Arcsine,
30-
Beta,
3130
BetaPrime,
3231
Biweight,
3332
Cauchy,
@@ -40,11 +39,9 @@ const flattened_dists = [ Bernoulli,
4039
Exponential,
4140
FDist,
4241
Frechet,
43-
Gamma,
4442
GeneralizedExtremeValue,
4543
GeneralizedPareto,
4644
Gumbel,
47-
#InverseGamma,
4845
InverseGaussian,
4946
Kolmogorov,
5047
Laplace,
@@ -54,8 +51,6 @@ const flattened_dists = [ Bernoulli,
5451
LogitNormal,
5552
LogNormal,
5653
Normal,
57-
#NormalCanon,
58-
#NormalInverseGaussian,
5954
Pareto,
6055
PGeneralizedGaussian,
6156
Rayleigh,
@@ -64,8 +59,6 @@ const flattened_dists = [ Bernoulli,
6459
TriangularDist,
6560
Triweight,
6661
TuringUniform,
67-
#Truncated,
68-
#VonMises,
6962
]
7063
for T in flattened_dists
7164
@eval toflatten(::$T) = true

src/zygote.jl

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,3 @@
1-
# Zygote fill has issues with non-numbers
2-
ZygoteRules.@adjoint function fill(x::T, dims...) where {T}
3-
return ZygoteRules.pullback(x, dims...) do x, dims...
4-
return reshape([x for i in 1:prod(dims)], dims)
5-
end
6-
end
7-
8-
91
## Uniform ##
102

113
ZygoteRules.@adjoint function Distributions.Uniform(args...)

0 commit comments

Comments
 (0)