Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit bf06bbd

Browse files
committedMay 9, 2022
formatting
1 parent ddc0c3d commit bf06bbd

28 files changed

+194
-157
lines changed
 

‎.JuliaFormatter.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ whitespace_ops_in_indices = false
66
remove_extra_newlines = true
77
import_to_using = false
88
pipe_to_function_call = false
9-
short_to_long_function_def = false
9+
short_to_long_function_def = true
1010
always_use_return = false
1111
whitespace_in_kwargs = true
1212
annotate_untyped_fields_with_any = false

‎src/MeasureBase.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -131,4 +131,4 @@ include("interface.jl")
131131

132132
using .Interface
133133

134-
end # module MeasureBase
134+
end # module MeasureBase

‎src/combinators/bind.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,4 +33,4 @@ function Base.rand(rng::AbstractRNG, ::Type{T}, d::Bind) where {T}
3333
x = rand(rng, T, d.μ)
3434
y = rand(rng, T, d.k(x))
3535
return y
36-
end
36+
end

‎src/combinators/conditional.jl

+8-5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
struct ConditionalMeasure{M,C} <: AbstractMeasure
2-
parent::M
2+
parent::M
33
constraint::C
44
end
55

@@ -37,10 +37,13 @@ condition(μ, constraint) = ConditionalMeasure(μ, constraint)
3737
# end
3838
# end
3939

40-
function Base.:|::ProductMeasure{NamedTuple{M,T}}, constraint::NamedTuple{N}) where {M,T,N}
41-
productmeasure(merge(marginals(μ),rmap(Dirac, constraint)))
40+
function Base.:|(
41+
μ::ProductMeasure{NamedTuple{M,T}},
42+
constraint::NamedTuple{N},
43+
) where {M,T,N}
44+
productmeasure(merge(marginals(μ), rmap(Dirac, constraint)))
4245
end
4346

4447
function Pretty.tile(d::ConditionalMeasure)
45-
Pretty.pair_layout(Pretty.tile(d.parent), Pretty.tile(d.constraint), sep=" | ")
46-
end
48+
Pretty.pair_layout(Pretty.tile(d.parent), Pretty.tile(d.constraint), sep = " | ")
49+
end

‎src/combinators/for.jl

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+

‎src/combinators/likelihood.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,8 @@ struct Likelihood{K,X} <: AbstractLikelihood
116116
k::K
117117
x::X
118118

119-
Likelihood(k::K, x::X) where {K<:AbstractTransitionKernel,X} = new{K,X}(k,x)
120-
Likelihood(k::K, x::X) where {K<:Function,X} = new{K,X}(k,x)
119+
Likelihood(k::K, x::X) where {K<:AbstractTransitionKernel,X} = new{K,X}(k, x)
120+
Likelihood(k::K, x::X) where {K<:Function,X} = new{K,X}(k, x)
121121
Likelihood(μ, x) = Likelihood(kernel(μ), x)
122122
end
123123

‎src/combinators/pointwise.jl

+3-5
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,10 @@ struct PointwiseProductMeasure{P,L} <: AbstractMeasure
55
likelihood::L
66
end
77

8-
9-
10-
iterate(p::PointwiseProductMeasure, i=1) = iterate((p.prior, p.likelihood), i)
8+
iterate(p::PointwiseProductMeasure, i = 1) = iterate((p.prior, p.likelihood), i)
119

1210
function Pretty.tile(d::PointwiseProductMeasure)
13-
Pretty.pair_layout(Pretty.tile(d.prior), Pretty.tile(d.likelihood), sep="")
11+
Pretty.pair_layout(Pretty.tile(d.prior), Pretty.tile(d.likelihood), sep = "")
1412
end
1513

1614
(μ, ℓ) = pointwiseproduct(μ, ℓ)
@@ -24,7 +22,7 @@ function gentype(d::PointwiseProductMeasure)
2422
gentype(d.prior)
2523
end
2624

27-
@inbounds function insupport(d::PointwiseProductMeasure, p)
25+
@inbounds function insupport(d::PointwiseProductMeasure, p)
2826
μ, ℓ = d
2927
insupport(μ, p) && insupport(ℓ.k(p), ℓ.x)
3028
end

‎src/combinators/power.jl

+14-7
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,11 @@ function Pretty.tile(μ::PowerMeasure)
2424
return Pretty.pair_layout(arg1, arg2; sep = " ^ ")
2525
end
2626

27-
function Base.rand(rng::AbstractRNG, ::Type{T}, d::PowerMeasure{M}) where {T, M<:AbstractMeasure}
27+
function Base.rand(
28+
rng::AbstractRNG,
29+
::Type{T},
30+
d::PowerMeasure{M},
31+
) where {T,M<:AbstractMeasure}
2832
map(CartesianIndices(d.axes)) do _
2933
rand(rng, T, d.parent)
3034
end
@@ -36,10 +40,10 @@ function Base.rand(rng::AbstractRNG, ::Type{T}, d::PowerMeasure) where {T}
3640
end
3741
end
3842

39-
@inline function powermeasure(x::T, sz::Tuple{Vararg{<:Any,N}}) where {T, N}
40-
a = axes(Fill{T, N}(x, sz))
43+
@inline function powermeasure(x::T, sz::Tuple{Vararg{<:Any,N}}) where {T,N}
44+
a = axes(Fill{T,N}(x, sz))
4145
A = typeof(a)
42-
PowerMeasure{T,A}(x,a)
46+
PowerMeasure{T,A}(x, a)
4347
end
4448

4549
marginals(d::PowerMeasure) = Fill(d.parent, d.axes)
@@ -61,7 +65,7 @@ params(d::PowerMeasure) = params(first(marginals(d)))
6165
# basemeasure(μ::PowerMeasure) = @inbounds basemeasure(first(μ.data))^size(μ.data)
6266

6367
@inline function basemeasure(d::PowerMeasure)
64-
basemeasure(d.parent) ^ d.axes
68+
basemeasure(d.parent)^d.axes
6569
end
6670

6771
@inline function logdensity_def(d::PowerMeasure{M}, x) where {M}
@@ -71,7 +75,10 @@ end
7175
end
7276
end
7377

74-
@inline function logdensity_def(d::PowerMeasure{M, Tuple{Base.OneTo{StaticInt{N}}}}, x) where {M,N}
78+
@inline function logdensity_def(
79+
d::PowerMeasure{M,Tuple{Base.OneTo{StaticInt{N}}}},
80+
x,
81+
) where {M,N}
7582
parent = d.parent
7683
sum(1:N) do j
7784
@inbounds logdensity_def(parent, x[j])
@@ -92,4 +99,4 @@ end
9299
# https://github.com/SciML/Static.jl/issues/36
93100
dynamic(insupport(p, xj))
94101
end
95-
end
102+
end

‎src/combinators/powerweighted.jl

+6-6
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,16 @@ end
77

88
logdensity_def(d::PowerWeightedMeasure, x) = d.exponent * logdensity_def(d.parent, x)
99

10-
basemeasure(d::PowerWeightedMeasure, x) = basemeasure(d.parent, x) d.exponent
10+
basemeasure(d::PowerWeightedMeasure, x) = basemeasure(d.parent, x)d.exponent
1111

12-
basemeasure(d::PowerWeightedMeasure) = basemeasure(d.parent) d.exponent
12+
basemeasure(d::PowerWeightedMeasure) = basemeasure(d.parent)d.exponent
1313

1414
function powerweightedmeasure(d, α)
1515
isone(α) && return d
1616
PowerWeightedMeasure(d, α)
1717
end
1818

19-
(d::AbstractMeasure) α = powerweightedmeasure(d, α)
19+
(d::AbstractMeasure)α = powerweightedmeasure(d, α)
2020

2121
insupport(d::PowerWeightedMeasure, x) = insupport(d.parent, x)
2222

@@ -29,9 +29,9 @@ function powerweightedmeasure(d::PowerWeightedMeasure, α)
2929
end
3030

3131
function powerweightedmeasure(d::WeightedMeasure, α)
32-
weightedmeasure*d.logweight, powerweightedmeasure(d.base, α))
32+
weightedmeasure * d.logweight, powerweightedmeasure(d.base, α))
3333
end
3434

3535
function Pretty.tile(d::PowerWeightedMeasure)
36-
Pretty.pair_layout(Pretty.tile(d.parent), Pretty.tile(d.exponent), sep="")
37-
end
36+
Pretty.pair_layout(Pretty.tile(d.parent), Pretty.tile(d.exponent), sep = "")
37+
end

‎src/combinators/product.jl

+34-17
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,23 @@ function Base.rand(rng::AbstractRNG, ::Type{T}, d::AbstractProductMeasure) where
3131
_rand_product(rng, T, mar, eltype(mar))
3232
end
3333

34-
function _rand_product(rng::AbstractRNG, ::Type{T}, mar, ::Type{M}) where {T,M<:AbstractMeasure}
34+
function _rand_product(
35+
rng::AbstractRNG,
36+
::Type{T},
37+
mar,
38+
::Type{M},
39+
) where {T,M<:AbstractMeasure}
3540
map(mar) do dⱼ
3641
rand(rng, T, dⱼ)
3742
end
3843
end
3944

40-
function _rand_product(rng::AbstractRNG, ::Type{T}, mar::ReadonlyMappedArray, ::Type{M}) where {T,M<:AbstractMeasure}
45+
function _rand_product(
46+
rng::AbstractRNG,
47+
::Type{T},
48+
mar::ReadonlyMappedArray,
49+
::Type{M},
50+
) where {T,M<:AbstractMeasure}
4151
mappedarray(mar.data) do dⱼ
4252
rand(rng, T, mar.f(dⱼ))
4353
end |> collect
@@ -49,14 +59,17 @@ function _rand_product(rng::AbstractRNG, ::Type{T}, mar, ::Type{M}) where {T,M}
4959
end
5060
end
5161

52-
53-
function _rand_product(rng::AbstractRNG, ::Type{T}, mar::ReadonlyMappedArray, ::Type{M}) where {T,M}
62+
function _rand_product(
63+
rng::AbstractRNG,
64+
::Type{T},
65+
mar::ReadonlyMappedArray,
66+
::Type{M},
67+
) where {T,M}
5468
mappedarray(mar.data) do dⱼ
5569
rand(rng, mar.f(dⱼ))
5670
end |> collect
5771
end
5872

59-
6073
@inline function logdensity_def(d::AbstractProductMeasure, x)
6174
mapreduce(logdensity_def, +, marginals(d), x)
6275
end
@@ -70,12 +83,12 @@ end
7083
end
7184

7285
function Pretty.tile(d::ProductMeasure{T}) where {T<:Tuple}
73-
Pretty.list_layout(Pretty.tile.([marginals(d)...]), sep="")
86+
Pretty.list_layout(Pretty.tile.([marginals(d)...]), sep = "")
7487
end
7588

7689
# For tuples, `mapreduce` has trouble with type inference
7790
@inline function logdensity_def(d::ProductMeasure{T}, x) where {T<:Tuple}
78-
ℓs = map(logdensity_def, marginals(d),x)
91+
ℓs = map(logdensity_def, marginals(d), x)
7992
sum(ℓs)
8093
end
8194

@@ -115,30 +128,35 @@ function basemeasure(μ::ProductMeasure{Base.Generator{I,F}}) where {I,F}
115128
_basemeasure(μ, B, static(Base.issingletontype(B)))
116129
end
117130

118-
119-
120131
function basemeasure::ProductMeasure{A}) where {T,A<:AbstractMappedArray{T}}
121132
B = Core.Compiler.return_type(basemeasure, Tuple{T})
122133
_basemeasure(μ, B, static(Base.issingletontype(B)))
123134
end
124135

125136
function _basemeasure::ProductMeasure, ::Type{B}, ::True) where {T,B}
126-
return instance(B) ^ axes(marginals(μ))
137+
return instance(B)^axes(marginals(μ))
127138
end
128139

129-
function _basemeasure::ProductMeasure{A}, ::Type{B}, ::False) where {T,A<:AbstractMappedArray{T},B}
140+
function _basemeasure(
141+
μ::ProductMeasure{A},
142+
::Type{B},
143+
::False,
144+
) where {T,A<:AbstractMappedArray{T},B}
130145
mar = marginals(μ)
131146
productmeasure(mappedarray(basemeasure, mar))
132147
end
133148

134-
function _basemeasure::ProductMeasure{Base.Generator{I,F}}, ::Type{B}, ::False) where {I,F,B}
149+
function _basemeasure(
150+
μ::ProductMeasure{Base.Generator{I,F}},
151+
::Type{B},
152+
::False,
153+
) where {I,F,B}
135154
mar = marginals(μ)
136155
productmeasure(Base.Generator(basekernel(mar.f), mar.iter))
137156
end
138157

139158
marginals::ProductMeasure) = μ.marginals
140159

141-
142160
testvalue(d::AbstractProductMeasure) = map(testvalue, marginals(d))
143161

144162
export
@@ -184,18 +202,17 @@ function _rand(rng::AbstractRNG, ::Type{T}, d::ProductMeasure, mar::AbstractArra
184202
rand!(rng, d, x)
185203
end
186204

187-
188205
@inline function insupport(d::AbstractProductMeasure, x::AbstractArray)
189206
mar = marginals(d)
190-
for (j,mj) in enumerate(mar)
207+
for (j, mj) in enumerate(mar)
191208
dynamic(insupport(mj, x[j])) || return false
192209
end
193210
return true
194211
end
195212

196213
@inline function insupport(d::AbstractProductMeasure, x)
197-
for (mj,xj) in zip(marginals(d), x)
214+
for (mj, xj) in zip(marginals(d), x)
198215
dynamic(insupport(mj, xj)) || return false
199216
end
200217
return true
201-
end
218+
end

‎src/combinators/restricted.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,4 @@ function Pretty.quoteof(d::RestrictedMeasure)
1313
qf = Pretty.quoteof(d.predicate)
1414
qbase = Pretty.quoteof(d.base)
1515
:(RestrictedMeasure($qf, $qbase))
16-
end
16+
end

‎src/combinators/smart-constructors.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,9 @@ productmeasure(tup::Tuple) = ProductMeasure(tup)
5353

5454
productmeasure(f, param_maps, pars) = ProductMeasure(kernel(f, param_maps), pars)
5555

56-
productmeasure(k::ParameterizedTransitionKernel, pars) = productmeasure(k.f, k.param_maps, pars)
57-
56+
function productmeasure(k::ParameterizedTransitionKernel, pars)
57+
productmeasure(k.f, k.param_maps, pars)
58+
end
5859

5960
function productmeasure(f::Returns{W}, ::typeof(identity), pars) where {W<:WeightedMeasure}
6061
= _logweight(f.value)
@@ -140,4 +141,3 @@ function kernel(::Type{M}; param_maps...) where {M}
140141
end
141142

142143
kernel(k::ParameterizedTransitionKernel) = k
143-

‎src/combinators/spikemixture.jl

+4-5
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@ struct SpikeMixture{M,W,S} <: AbstractMeasure
88
s::S # spike weight
99
end
1010

11-
SpikeMixture(μ,w) = SpikeMixture(μ, w, static(1.0) - w)
11+
SpikeMixture(μ, w) = SpikeMixture(μ, w, static(1.0) - w)
1212

1313
function Pretty.tile::SpikeMixture)
14-
Pretty.list_layout(Pretty.tile.([μ.m, μ.w]), prefix="SpikeMixture")
14+
Pretty.list_layout(Pretty.tile.([μ.m, μ.w]), prefix = "SpikeMixture")
1515
end
1616

1717
# TODO: Should this base measure be local?
@@ -21,10 +21,9 @@ end
2121
SpikeMixture(basemeasure.m), static(1.0), static(1.0))
2222
end
2323

24-
2524
@inline function logdensity_def::SpikeMixture, x)
2625
if iszero(x)
27-
return log.s)
26+
return log.s)
2827
else
2928
return log.w) + logdensity_def.m, x)
3029
end
@@ -40,4 +39,4 @@ end
4039

4140
testvalue::SpikeMixture) = testvalue.m)
4241

43-
insupport::SpikeMixture, x) = dynamic(insupport.m, x)) || iszero(x)
42+
insupport::SpikeMixture, x) = dynamic(insupport.m, x)) || iszero(x)

‎src/combinators/transformedmeasure.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,4 @@ function params(::AbstractTransformedMeasure) end
1212

1313
function paramnames(::AbstractTransformedMeasure) end
1414

15-
function parent(::AbstractTransformedMeasure) end
15+
function parent(::AbstractTransformedMeasure) end

‎src/combinators/weighted.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@ _logweight(μ::WeightedMeasure) = μ.logweight
3131
basemeasure::AbstractWeightedMeasure) = μ.base
3232

3333
function Pretty.tile(d::WeightedMeasure)
34-
weight = round(exp(d.logweight), sigdigits=4)
35-
Pretty.pair_layout(Pretty.tile(weight), Pretty.tile(d.base), sep=" * ")
34+
weight = round(exp(d.logweight), sigdigits = 4)
35+
Pretty.pair_layout(Pretty.tile(weight), Pretty.tile(d.base), sep = " * ")
3636
end
3737

3838
function Base.:*(k::T, m::AbstractMeasure) where {T<:Number}
@@ -47,4 +47,4 @@ Base.:*(m::AbstractMeasure, k::Real) = k * m
4747

4848
gentype::WeightedMeasure) = gentype.base)
4949

50-
insupport::WeightedMeasure, x) = insupport.base, x)
50+
insupport::WeightedMeasure, x) = insupport.base, x)

‎src/density.jl

+23-14
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,6 @@ logdensity_def(μ::DensityMeasure, x) = logdensityof(μ.f, x)
7878

7979
density_def::DensityMeasure, x) = densityof.f, x)
8080

81-
8281
export
8382

8483
"""
@@ -142,12 +141,12 @@ See also `logdensityof`.
142141
ℓ_0 = logdensity_def(μ, x)
143142
b_0 = μ
144143
Base.Cartesian.@nexprs 10 i -> begin # 10 is just some "big enough" number
145-
b_{i} = basemeasure(b_{i-1}, x)
146-
if b_{i} isa typeof(b_{i-1})
147-
return ℓ_{i-1}
144+
b_{i} = basemeasure(b_{i - 1}, x)
145+
if b_{i} isa typeof(b_{i - 1})
146+
return ℓ_{i - 1}
148147
end
149148
ℓ_{i} = let Δℓ_{i} = logdensity_def(b_{i}, x)
150-
ℓ_{i-1} + Δℓ_{i}
149+
ℓ_{i - 1} + Δℓ_{i}
151150
end
152151
end
153152
return ℓ_10
@@ -162,7 +161,7 @@ export logdensity_rel
162161
@inline return_type(f, args::Tuple) = Core.Compiler.return_type(f, Tuple{typeof.(args)...})
163162

164163
unstatic(::Type{T}) where {T} = T
165-
unstatic(::Type{StaticFloat64{X}}) where X = Float64
164+
unstatic(::Type{StaticFloat64{X}}) where {X} = Float64
166165

167166
"""
168167
logdensity_rel(m1, m2, x)
@@ -173,7 +172,12 @@ known to be in the support of both, it can be more efficient to call
173172
`unsafe_logdensity_rel`.
174173
"""
175174
@inline function logdensity_rel::M, ν::N, x::X) where {M,N,X}
176-
T = unstatic(promote_type(return_type(logdensity_def, (μ, x)), return_type(logdensity_def, (ν, x))))
175+
T = unstatic(
176+
promote_type(
177+
return_type(logdensity_def, (μ, x)),
178+
return_type(logdensity_def, (ν, x)),
179+
),
180+
)
177181
inμ = insupport(μ, x)
178182
inν = insupport(ν, x)
179183
inμ || return convert(T, ifelse(inν, -Inf, NaN))
@@ -191,7 +195,7 @@ known to be in the support of both `m1` and `m2`.
191195
See also `logdensity_rel`.
192196
"""
193197
@inline function unsafe_logdensity_rel::M, ν::N, x::X) where {M,N,X}
194-
if static_hasmethod(logdensity_def, Tuple{M, N, X})
198+
if static_hasmethod(logdensity_def, Tuple{M,N,X})
195199
return logdensity_def(μ, ν, x)
196200
end
197201
μs = basemeasure_sequence(μ)
@@ -221,19 +225,24 @@ function logdensity_def(μ::T, ν::T, x) where {T}
221225
if μ === ν
222226
return zero(logdensity_def(μ, x))
223227
else
224-
return logdensity_def(μ,x) - logdensity_def(ν, x)
228+
return logdensity_def(μ, x) - logdensity_def(ν, x)
225229
end
226230
end
227231

228-
@generated function _logdensity_rel(μs::Tμ, νs::Tν, ::Tuple{StaticInt{M},StaticInt{N}}, x::X) where {Tμ, Tν,M,N,X}
232+
@generated function _logdensity_rel(
233+
μs::Tμ,
234+
νs::Tν,
235+
::Tuple{StaticInt{M},StaticInt{N}},
236+
x::X,
237+
) where {Tμ,Tν,M,N,X}
229238
= schema(Tμ)
230239
= schema(Tν)
231-
232-
q = quote
240+
241+
q = quote
233242
$(Expr(:meta, :inline))
234243
= logdensity_def(μs[$M], νs[$N], x)
235244
end
236-
245+
237246
for i in 1:M-1
238247
push!(q.args, :(Δℓ = logdensity_def(μs[$i], x)))
239248
# push!(q.args, :(println("Adding", Δℓ)))
@@ -267,4 +276,4 @@ basemeasure(rebase(μ, ν)) == ν
267276
density(rebase(μ, ν)) == 𝒹(μ,ν)
268277
```
269278
"""
270-
rebase(μ, ν) = (𝒹(μ,ν), ν)
279+
rebase(μ, ν) = (𝒹(μ, ν), ν)

‎src/domains.jl

+12-9
Original file line numberDiff line numberDiff line change
@@ -74,30 +74,34 @@ end
7474

7575
export ZeroSet
7676

77-
struct ZeroSet{F, G} <: AbstractDomain
77+
struct ZeroSet{F,G} <: AbstractDomain
7878
f::F
7979
∇f::G
8080
end
8181

8282
# Based on some quick tests, but may need some adjustment
8383
Base.in(x::AbstractArray{T}, z::ZeroSet) where {T} = abs(z.f(x)) < ldexp(eps(float(T)), 6)
8484

85-
8685
###########################################################
8786
# CodimOne
8887

8988
export CodimOne
9089

9190
abstract type CodimOne <: AbstractDomain end
9291

93-
function tangentat(a::CodimOne, b::CodimOne, x::AbstractArray{T}; tol=ldexp(eps(float(T)), 6)) where {T}
92+
function tangentat(
93+
a::CodimOne,
94+
b::CodimOne,
95+
x::AbstractArray{T};
96+
tol = ldexp(eps(float(T)), 6),
97+
) where {T}
9498
# Sometimes you get lucky
9599
a == b && return true
96100

97101
# Get the normal vectors
98102
g1 = a.∇f(x)
99103
g2 = b.∇f(x)
100-
104+
101105
# See if one is a multiple of the other
102106
one(T) - Statistics.corm(g1, zero(T), g2, zero(T)) < tol
103107
end
@@ -110,13 +114,13 @@ export Simplex
110114

111115
struct Simplex <: CodimOne end
112116

113-
function zeroset(::Simplex)
117+
function zeroset(::Simplex)
114118
f(x::AbstractArray{T}) where {T} = sum(x) - one(T)
115119
∇f(x::AbstractArray{T}) where {T} = Fill(one(T), size(x))
116120
ZeroSet(f, ∇f)
117121
end
118122

119-
function Base.in(x::AbstractArray{T}, ::Simplex) where {T}
123+
function Base.in(x::AbstractArray{T}, ::Simplex) where {T}
120124
all((zero(eltype(x))), x) || return false
121125
return x zeroset(Simplex())
122126
end
@@ -128,15 +132,14 @@ projectto!(x, ::Simplex) = normalize!(x, 1)
128132

129133
struct Sphere <: CodimOne end
130134

131-
function zeroset(::Sphere)
135+
function zeroset(::Sphere)
132136
f(x::AbstractArray{T}) where {T} = sum(xⱼ -> xⱼ^2, x) - one(T)
133137
∇f(x::AbstractArray{T}) where {T} = x
134138
ZeroSet(f, ∇f)
135139
end
136140

137-
function Base.in(x::AbstractArray{T}, ::Sphere) where {T}
141+
function Base.in(x::AbstractArray{T}, ::Sphere) where {T}
138142
return x zeroset(Sphere())
139143
end
140144

141145
projectto!(x, ::Sphere) = normalize!(x, 2)
142-

‎src/interface.jl

+3-6
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using Reexport
44

55
@reexport using MeasureBase
66

7-
using MeasureBase:basemeasure_depth, proxy
7+
using MeasureBase: basemeasure_depth, proxy
88
using MeasureBase: insupport, basemeasure_sequence, commonbase
99

1010
export test_interface
@@ -31,10 +31,10 @@ function test_interface(μ::M) where {M}
3131
μ = $μ
3232
@testset "" begin
3333
μ = $μ
34-
34+
3535
###########################################################################
3636
# basemeasure_depth
37-
static_depth = @inferred basemeasure_depth(μ)
37+
static_depth = @inferred basemeasure_depth(μ)
3838

3939
dynamic_depth = dynamic_basemeasure_depth(μ)
4040

@@ -58,6 +58,3 @@ function test_interface(μ::M) where {M}
5858
end
5959

6060
end # module Interface
61-
62-
63-

‎src/kernel.jl

+8-6
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,15 @@ struct ParameterizedTransitionKernel{F,N,T} <: AbstractTransitionKernel
55
f::F
66
param_maps::NamedTuple{N,T}
77

8-
ParameterizedTransitionKernel(::Type{F}, param_maps::NamedTuple{N,T}) where {F,N,T} =
8+
function ParameterizedTransitionKernel(
9+
::Type{F},
10+
param_maps::NamedTuple{N,T},
11+
) where {F,N,T}
912
new{Type{F},N,T}(F, param_maps)
10-
ParameterizedTransitionKernel(f::F, param_maps::NamedTuple{N,T}) where {F,N,T} =
13+
end
14+
function ParameterizedTransitionKernel(f::F, param_maps::NamedTuple{N,T}) where {F,N,T}
1115
new{F,N,T}(f, param_maps)
16+
end
1217
end
1318

1419
"""
@@ -32,7 +37,6 @@ another common use of this term.
3237
"""
3338
function kernel end
3439

35-
3640
# kernel(Normal) do x
3741
# (μ=x,σ=x^2)
3842
# end
@@ -51,7 +55,6 @@ function (k::ParameterizedTransitionKernel)(x::Tuple)
5155
k.f(NamedTuple{k.param_maps}(x))
5256
end
5357

54-
5558
"""
5659
For any `k::TransitionKernel`, `basekernel` is expected to satisfy
5760
```
@@ -72,7 +75,6 @@ basekernel(k::ParameterizedTransitionKernel) = kernel(basekernel(k.f), k.param_m
7275

7376
basekernel(f::Returns) = Returns(basemeasure(f.value))
7477

75-
7678
function Base.show(io::IO, μ::AbstractTransitionKernel)
7779
io = IOContext(io, :compact => true)
7880
Pretty.pprint(io, μ)
@@ -86,4 +88,4 @@ end
8688

8789
const kleisli = kernel
8890

89-
export kleisli
91+
export kleisli

‎src/parameterized.jl

+2-4
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ end
2828
#
2929

3030
function kernel(::Type{P}) where {N,P<:ParameterizedMeasure{N}}
31-
C = constructorof(P)
31+
C = constructorof(P)
3232
_kernel(C, Val(N))
3333
end
3434

@@ -37,7 +37,7 @@ end
3737
C(NamedTuple{N,T}(args))::C{N,T}
3838
end
3939

40-
@inline function f(arg::T) where {T}
40+
@inline function f(arg::T) where {T}
4141
C(NamedTuple{N,Tuple{T}}((arg,)))::C{N,Tuple{T}}
4242
end
4343

@@ -113,8 +113,6 @@ See also `params`
113113
"""
114114
function paramnames end
115115

116-
117-
118116
paramnames::M) where {M} = paramnames(M)
119117

120118
paramnames(::Type{PM}) where {N,PM<:ParameterizedMeasure{N}} = N

‎src/primitives/counting.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,4 @@ Base.show(io::IO, d::Counting) = print(io, "Counting(", d.support, ")")
3030

3131
insupport::Counting, x) = x μ.support
3232

33-
insupport::Counting{T}, x) where {T<:Type} = x isa μ.support
33+
insupport::Counting{T}, x) where {T<:Type} = x isa μ.support

‎src/primitives/lebesgue.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ struct Lebesgue{T} <: AbstractMeasure
1515
end
1616

1717
function Pretty.tile::Lebesgue)
18-
Pretty.list_layout([Pretty.tile.support)]; prefix=:Lebesgue)
18+
Pretty.list_layout([Pretty.tile.support)]; prefix = :Lebesgue)
1919
end
2020

2121
gentype(::Lebesgue) = Float64
@@ -39,4 +39,4 @@ insupport(::Lebesgue{RealNumbers}, ::Real) = true
3939

4040
logdensity_def(::LebesgueMeasure, ::CountingMeasure, x) = -Inf
4141

42-
logdensity_def(::CountingMeasure, ::LebesgueMeasure, x) = Inf
42+
logdensity_def(::CountingMeasure, ::LebesgueMeasure, x) = Inf

‎src/schema.jl

+9-9
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Taken from https://github.com/cscherrer/NestedTuples.jl/blob/cd298fd1e5f7e571701a6fee916d2d47c19f32f5/src/typelevel.jl
22

3-
ntkeys(::Type{NamedTuple{K,V}}) where {K, V} = K
4-
ntvaltype(::Type{NamedTuple{K,V}}) where {K, V} = V
3+
ntkeys(::Type{NamedTuple{K,V}}) where {K,V} = K
4+
ntvaltype(::Type{NamedTuple{K,V}}) where {K,V} = V
55

66
"""
77
schema(::Type)
@@ -16,19 +16,19 @@ Example:
1616
"""
1717
function schema end
1818

19-
schema(::NamedTuple{(), Tuple{}}) = NamedTuple()
20-
schema(::Type{NamedTuple{(), Tuple{}}}) = NamedTuple()
19+
schema(::NamedTuple{(),Tuple{}}) = NamedTuple()
20+
schema(::Type{NamedTuple{(),Tuple{}}}) = NamedTuple()
2121

22-
function schema(NT::Type{NamedTuple{names, T}}) where {names, T}
22+
function schema(NT::Type{NamedTuple{names,T}}) where {names,T}
2323
return NamedTuple{ntkeys(NT)}(schema(ntvaltype(NT)))
2424
end
2525

26-
function schema(TT::Type{T}) where {T <: Tuple}
26+
function schema(TT::Type{T}) where {T<:Tuple}
2727
return schema.(Tuple(TT.types))
2828
end
2929

30-
schema(t::T) where {T <: Tuple} = schema(T)
30+
schema(t::T) where {T<:Tuple} = schema(T)
3131

32-
schema(t::T) where {T <: NamedTuple} = schema(T)
32+
schema(t::T) where {T<:NamedTuple} = schema(T)
3333

34-
schema(T) = T
34+
schema(T) = T

‎src/splat.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ struct Splat{F}
22
f::F
33
end
44

5-
function (s::Splat{F})(x) where F
5+
function (s::Splat{F})(x) where {F}
66
s.f(x...)
77
end
88

‎src/utils.jl

+19-15
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ end
88
showparams(io::IO, ::EmptyNamedTuple) = print(io, "()")
99
showparams(io::IO, nt::NamedTuple) = print(io, nt)
1010

11-
1211
export testvalue
1312
testvalue::AbstractMeasure) = testvalue(basemeasure(μ))
1413
testvalue(::Type{T}) where {T} = zero(T)
@@ -47,7 +46,6 @@ end
4746

4847
# issingletontype(@nospecialize(t)) = (@_pure_meta; isa(t, DataType) && isdefined(t, :instance))
4948

50-
5149
# @generated function instance(::Type{T}) where {T}
5250
# return getfield(T, :instance)::T
5351
# end
@@ -61,9 +59,9 @@ export basemeasure_depth
6159
@inline function basemeasure_depth::M) where {M}
6260
b_0 = μ
6361
Base.Cartesian.@nexprs 10 i -> begin # 10 is just some "big enough" number
64-
b_{i} = basemeasure(b_{i-1})
65-
if b_{i} isa typeof(b_{i-1})
66-
return static(i-1)
62+
b_{i} = basemeasure(b_{i - 1})
63+
if b_{i} isa typeof(b_{i - 1})
64+
return static(i - 1)
6765
end
6866
end
6967
return static(10)
@@ -79,16 +77,20 @@ measure of the previous term, and with no repeated entries.
7977
b_1 = μ
8078
done = false
8179
Base.Cartesian.@nexprs 10 i -> begin # 10 is just some "big enough" number
82-
b_{i+1} = if done nothing else basemeasure(b_{i}) end
83-
if b_{i+1} isa typeof(b_{i})
80+
b_{i + 1} = if done
81+
nothing
82+
else
83+
basemeasure(b_{i})
84+
end
85+
if b_{i + 1} isa typeof(b_{i})
8486
done = true
85-
b_{i+1} = nothing
87+
b_{i + 1} = nothing
8688
end
8789
end
8890
return filter(!isnothing, Base.Cartesian.@ntuple 10 b)
8991
end
9092

91-
commonbase(μ, ν) = commonbase(μ, ν, Any)
93+
commonbase(μ, ν) = commonbase(μ, ν, Any)
9294

9395
"""
9496
commonbase(μ, ν, T) -> Tuple{StaticInt{i}, StaticInt{j}}
@@ -107,9 +109,12 @@ end
107109
m = schema(M)
108110
n = schema(N)
109111

110-
sols = Iterators.filter(((i,j),) -> static_hasmethod(logdensity_def, Tuple{m[i], n[j], T}), Iterators.product(1:length(m), 1:length(n)))
112+
sols = Iterators.filter(
113+
((i, j),) -> static_hasmethod(logdensity_def, Tuple{m[i],n[j],T}),
114+
Iterators.product(1:length(m), 1:length(n)),
115+
)
111116
isempty(sols) && return :(nothing)
112-
minsol = static.(argmin(((i,j),) -> i+j, sols))
117+
minsol = static.(argmin(((i, j),) -> i + j, sols))
113118
quote
114119
$minsol
115120
end
@@ -131,15 +136,14 @@ end
131136
return true
132137
end
133138

134-
135139
allequal(x::AbstractArray) = allequal(identity, x)
136140

137141
rmap(f, x) = f(x)
138142

139143
function rmap(f, t::Tuple)
140-
map(x -> rmap(f,x), t)
144+
map(x -> rmap(f, x), t)
141145
end
142146

143147
function rmap(f, nt::NamedTuple{N,T}) where {N,T}
144-
NamedTuple{N}(map(x -> rmap(f,x), values(nt)))
145-
end
148+
NamedTuple{N}(map(x -> rmap(f, x), values(nt)))
149+
end

‎test/combinators/superpose.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ using MeasureBase: superpose
1515
μs = SuperpositionMeasure([μ, ν])
1616
@test μs isa SuperpositionMeasure{<:AbstractVector{<:AbstractMeasure}}
1717
@test_throws ErrorException density_def(μs, 0)
18-
@test basemeasure(μs).components == SuperpositionMeasure([CountingMeasure(), CountingMeasure()]).components
18+
@test basemeasure(μs).components ==
19+
SuperpositionMeasure([CountingMeasure(), CountingMeasure()]).components
1920

2021
μ2 = μ + μ
2122
@test μ2 isa WeightedMeasure

‎test/combinators/weighted.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ using MeasureBase: _logweight, weightedmeasure, WeightedMeasure
1313
@test μ isa WeightedMeasure
1414
@test _logweight(μ) == log(w)
1515
@test _logweight(w * μ) == 2 * log(w)
16-
@test rand(MersenneTwister(123), μ) == rand(MersenneTwister(123), μ₀)
16+
@test rand(MersenneTwister(123), μ) == rand(MersenneTwister(123), μ₀)
1717
x = rand()
1818
@test logdensity_def(μ, x) == log(w)
1919
@test logdensityof(μ, x) == logdensityof(μ₀, x)

‎test/runtests.jl

+28-30
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ d = ∫exp(x -> -x^2, Lebesgue(ℝ))
2020
# return (x,y)
2121
# end
2222

23-
2423
test_measures = [
2524
# Chain(x -> Normal(μ=x), Normal(μ=0.0))
2625
# For(3) do j
@@ -53,8 +52,8 @@ testbroken_measures = [
5352
for μ in test_measures
5453
@info "testing "
5554
test_interface(μ)
56-
test_interface ^ 3)
57-
test_interface ^ (3,2))
55+
test_interface^3)
56+
test_interface^(3, 2))
5857
test_interface(5 * μ)
5958
# test_interface(SpikeMixture(μ, 0.2))
6059
end
@@ -105,10 +104,9 @@ end
105104
# end
106105
# end
107106

108-
109107
@testset "powers" begin
110-
@test logdensityof(Lebesgue() ^ 3, 2) == logdensityof(Lebesgue() ^ (3,), 2)
111-
@test logdensityof(Lebesgue() ^ 3, 2) == logdensityof(Lebesgue() ^ (3,1), (2,0))
108+
@test logdensityof(Lebesgue()^3, 2) == logdensityof(Lebesgue()^(3,), 2)
109+
@test logdensityof(Lebesgue()^3, 2) == logdensityof(Lebesgue()^(3, 1), (2, 0))
112110
end
113111

114112
@testset "Half" begin
@@ -176,41 +174,41 @@ end
176174
end
177175

178176
@testset "logdensity_rel" begin
179-
@test logdensity_rel(Dirac(0.0)+Lebesgue(), Dirac(1.0), 0.0) == Inf
180-
@test logdensity_rel(Dirac(0.0)+Lebesgue(), Dirac(1.0), 1.0) == -Inf
181-
@test logdensity_rel(Dirac(0.0)+Lebesgue(), Dirac(1.0), 2.0) == Inf
182-
@test logdensity_rel(Dirac(0.0)+Lebesgue(), Dirac(0.0), 0.0) == 0.0
183-
@test logdensity_rel(Dirac(0.0)+Lebesgue(), Dirac(0.0), 1.0) == Inf
184-
@test logdensity_rel(Dirac(0.0)+Lebesgue(), Lebesgue(), 0.0) == Inf
185-
@test logdensity_rel(Dirac(0.0)+Lebesgue(), Lebesgue(), 1.0) == 0.0
186-
187-
@test logdensity_rel(Dirac(1.0), Dirac(0.0)+Lebesgue(), 0.0) == -Inf
188-
@test logdensity_rel(Dirac(1.0), Dirac(0.0)+Lebesgue(), 1.0) == Inf
189-
@test logdensity_rel(Dirac(1.0), Dirac(0.0)+Lebesgue(), 2.0) == -Inf
190-
@test logdensity_rel(Dirac(0.0), Dirac(0.0)+Lebesgue(), 0.0) == 0.0
191-
@test logdensity_rel(Dirac(0.0), Dirac(0.0)+Lebesgue(), 1.0) == -Inf
192-
@test logdensity_rel(Lebesgue(), Dirac(0.0)+Lebesgue(), 0.0) == -Inf
193-
@test logdensity_rel(Lebesgue(), Dirac(0.0)+Lebesgue(), 1.0) == 0.0
194-
177+
@test logdensity_rel(Dirac(0.0) + Lebesgue(), Dirac(1.0), 0.0) == Inf
178+
@test logdensity_rel(Dirac(0.0) + Lebesgue(), Dirac(1.0), 1.0) == -Inf
179+
@test logdensity_rel(Dirac(0.0) + Lebesgue(), Dirac(1.0), 2.0) == Inf
180+
@test logdensity_rel(Dirac(0.0) + Lebesgue(), Dirac(0.0), 0.0) == 0.0
181+
@test logdensity_rel(Dirac(0.0) + Lebesgue(), Dirac(0.0), 1.0) == Inf
182+
@test logdensity_rel(Dirac(0.0) + Lebesgue(), Lebesgue(), 0.0) == Inf
183+
@test logdensity_rel(Dirac(0.0) + Lebesgue(), Lebesgue(), 1.0) == 0.0
184+
185+
@test logdensity_rel(Dirac(1.0), Dirac(0.0) + Lebesgue(), 0.0) == -Inf
186+
@test logdensity_rel(Dirac(1.0), Dirac(0.0) + Lebesgue(), 1.0) == Inf
187+
@test logdensity_rel(Dirac(1.0), Dirac(0.0) + Lebesgue(), 2.0) == -Inf
188+
@test logdensity_rel(Dirac(0.0), Dirac(0.0) + Lebesgue(), 0.0) == 0.0
189+
@test logdensity_rel(Dirac(0.0), Dirac(0.0) + Lebesgue(), 1.0) == -Inf
190+
@test logdensity_rel(Lebesgue(), Dirac(0.0) + Lebesgue(), 0.0) == -Inf
191+
@test logdensity_rel(Lebesgue(), Dirac(0.0) + Lebesgue(), 1.0) == 0.0
192+
195193
@test isnan(logdensity_rel(Dirac(0), Dirac(1), 2))
196194
end
197195

198196
@testset "Density measures and Radon-Nikodym" begin
199197
x = randn()
200198
f(x) = x^2
201-
@test logdensityof(𝒹(∫exp(f, Lebesgue()), Lebesgue()),x ) f(x)
199+
@test logdensityof(𝒹(∫exp(f, Lebesgue()), Lebesgue()), x) f(x)
202200

203201
let f = 𝒹(∫exp(x -> x^2, Lebesgue()), Lebesgue())
204-
@test logdensityof(f,x) x^2
202+
@test logdensityof(f, x) x^2
205203
end
206204

207-
# let d = ∫exp(log𝒹(Cauchy(), Normal()), Normal())
208-
# @test logdensity_def(d, x) ≈ logdensity_def(Cauchy(), x)
209-
# end
205+
# let d = ∫exp(log𝒹(Cauchy(), Normal()), Normal())
206+
# @test logdensity_def(d, x) ≈ logdensity_def(Cauchy(), x)
207+
# end
210208

211-
# let f = log𝒹(∫exp(x -> x^2, Normal()), Normal())
212-
# @test f(x) ≈ x^2
213-
# end
209+
# let f = log𝒹(∫exp(x -> x^2, Normal()), Normal())
210+
# @test f(x) ≈ x^2
211+
# end
214212
end
215213

216214
include("combinators/weighted.jl")

0 commit comments

Comments
 (0)
Please sign in to comment.