-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathinterface.jl
139 lines (112 loc) · 3.98 KB
/
interface.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
module Interface
using Reexport
@reexport using MeasureBase
using MeasureBase: basemeasure_depth, proxy, istrue
using MeasureBase: insupport, basemeasure_sequence, commonbase
using MeasureBase: transport_to, NoTransport
using DensityInterface: logdensityof
using InverseFunctions: inverse
using ChangesOfVariables: with_logabsdet_jacobian
using Tricks: static_hasmethod
using IntervalSets: Interval
export test_interface
export test_transport
export test_smf
export basemeasure_depth
export proxy
export insupport
export basemeasure_sequence
export commonbase
using Test
function dynamic_basemeasure_depth(μ::M) where {M}
if static_hasmethod(proxy, Tuple{M})
π = proxy(μ)
if static_hasmethod(basemeasure, Tuple{typeof(π)})
basemeasure(π) == basemeasure(μ) && return dynamic_basemeasure_depth(π)
end
end
β = basemeasure(μ)
depth = 0
while μ ≠ β
(μ, β) = (β, basemeasure(β))
depth += 1
end
return depth
end
function test_interface(μ::M) where {M}
@eval begin
μ = $μ
@testset "$μ" begin
μ = $μ
@test AbstractMeasure(μ) isa AbstractMeasure
###########################################################################
# basemeasure_depth
static_depth = @inferred basemeasure_depth(μ)
dynamic_depth = dynamic_basemeasure_depth(μ)
if static_depth > dynamic_depth
@warn "basemeasure_depth($μ) greater than requirement, could add some overhead"
end
@test static_depth ≥ dynamic_depth
###########################################################################
# testvalue, logdensityof
x = @inferred testvalue(Float64, μ)
β = @inferred basemeasure(μ, x)
@test logdensity_def(μ, x) ≈ logdensity_rel(μ, β, x)
ℓμ = @inferred logdensityof(μ, x)
ℓβ = @inferred logdensityof(β, x)
@test ℓμ ≈ logdensity_def(μ, x) + ℓβ
@test logdensity_def(μ, testvalue(Float64, μ)) isa Real
end
end
end
function test_transport(ν, μ)
supertype(x) = Any
supertype(x::Real) = Real
supertype(x::AbstractArray{<:Real,N}) where {N} = AbstractArray{<:Real,N}
structisapprox(a, b) = isapprox(a, b)
function structisapprox(a::NTuple{N,Any}, b::NTuple{N,Any}) where {N}
all(map(structisapprox, a, b))
end
function structisapprox(a::NamedTuple{names}, b::NamedTuple{names}) where {names}
all(map(structisapprox, values(a), values(b)))
end
@testset "transport_to $μ to $ν" begin
x = rand(μ)
@test !(@inferred(transport_to(ν, μ)(x)) isa NoTransport)
f = transport_to(ν, μ)
y = f(x)
@test structisapprox(@inferred(inverse(f)(y)), x)
@test @inferred(with_logabsdet_jacobian(f, x)) isa Tuple{supertype(y),Real}
@test @inferred(with_logabsdet_jacobian(inverse(f), y)) isa Tuple{supertype(x),Real}
y2, ladj_fwd = with_logabsdet_jacobian(f, x)
x2, ladj_inv = with_logabsdet_jacobian(inverse(f), y)
@test structisapprox(x, x2)
@test structisapprox(y, y2)
@test isapprox(ladj_fwd, -ladj_inv, atol = 1e-10)
@test ladj_fwd ≈ logdensityof(μ, x) - logdensityof(ν, y)
end
end
function test_smf(μ, n = 100)
@testset "smf($μ)" begin
# Get `n` sorted uniforms in O(n) time
p = rand(n)
p .+= 0:n-1
p .*= inv(n)
F(x) = smf(μ, x)
Finv(p) = invsmf(μ, p)
@assert issorted(p)
x = invsmf.(μ, p)
@test issorted(x)
@test all(istrue ∘ insupport(μ), x)
@test all((Finv ∘ F).(x) .≈ x)
for j in 1:n
a = rand()
b = rand()
a, b = minmax(a, b)
x = Finv(a)
y = Finv(b)
@test μ(Interval{:open,:closed}(x, y)) ≈ (F(y) - F(x))
end
end
end
end # module Interface