-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathapi.jl
90 lines (72 loc) · 2.07 KB
/
api.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
using HDPHMM
using Test
using Random
using Distributions
using Missings
import ConjugatePriors: NormalInverseChisq
"""
    getprior()

Return the `BlockedSamplerPrior` shared by the sampler test sets in this file.

The transition part is a `TransitionDistributionPrior` built from two
`Gamma(1, 1/0.001)` distributions and a `Beta(50, 1)`; the observation part is a
`DPMMObservationModelPrior{Normal}` built from `NormalInverseChisq(1, 1, 1, 1)`
and `Gamma(1, 0.5)`. The leading `1.0` is passed to `BlockedSamplerPrior` as-is.
"""
function getprior()
    # Both Gamma slots of the transition prior use the same distribution.
    gamma_prior = Gamma(1, 1/0.001)
    transition_prior = TransitionDistributionPrior(gamma_prior, gamma_prior, Beta(50, 1))

    observation_prior = DPMMObservationModelPrior{Normal}(
        NormalInverseChisq(1, 1, 1, 1),
        Gamma(1, 0.5),
    )

    return BlockedSamplerPrior(1.0, transition_prior, observation_prior)
end
@testset "Sampler API" begin
    # Seed the default RNG so the generated data (and presumably the sampler's own
    # draws, if it uses the default RNG) are reproducible when this test fails.
    Random.seed!(1234)

    L, LP = 10, 5
    sampler = BlockedSampler(L, LP)
    prior = getprior()
    data = rand(2520)
    # Two chains with `verb = true`; the call must produce no stderr output.
    config = MCConfig(
        chains = 2,
        verb = true
    )
    @test_nowarn sample(sampler, prior, data, config = config)
end
@testset "Sampler API - Init" begin
    # Seed the default RNG so the generated data (and presumably the sampler's own
    # draws, if it uses the default RNG) are reproducible when this test fails.
    Random.seed!(1234)

    L, LP = 10, 5
    sampler = BlockedSampler(L, LP)
    prior = getprior()
    data = allowmissing(rand(1000))
    # Run the sampler once with each initialisation (k-means, bins, fixed); every
    # run must complete without any stderr output.
    @test_nowarn sample(sampler, prior, data, config = MCConfig(init = KMeansInit(L)))
    @test_nowarn sample(sampler, prior, data, config = MCConfig(init = BinsInit(L)))
    @test_nowarn sample(sampler, prior, data,
                        config = MCConfig(init = FixedInit(ones(length(data)))))
end
@testset "Sampler API - Missings" begin
    # Seed the default RNG so the generated data (and presumably the sampler's own
    # draws, if it uses the default RNG) are reproducible when this test fails.
    Random.seed!(1234)

    L, LP = 10, 5
    sampler = BlockedSampler(L, LP)
    prior = getprior()
    # Blank out a contiguous run of 11 observations (indices 100:110) so the
    # sampler has to cope with `missing` entries.
    data = allowmissing(rand(1000))
    data[100:110] .= missing
    # NOTE(review): nothing here undoes `enablemissing`; if it sets package-global
    # state, it will also be active in the test sets that run after this one — confirm.
    HDPHMM.enablemissing(1.0)
    @test_nowarn sample(sampler, prior, data)
end
@testset "Chain" begin
    # Seed the default RNG so the generated data (and presumably the sampler's own
    # draws, if it uses the default RNG) are reproducible when this test fails.
    Random.seed!(1234)

    L, LP = 10, 5
    sampler = BlockedSampler(L, LP)
    prior = getprior()
    data = rand(1000)
    chains = sample(sampler, prior, data)
    # `select_hamming` on the first returned chain must produce no stderr output.
    @test_nowarn select_hamming(chains[1])
end
@testset "Cleaning API" begin
    # Unsorted input with a gap wider than the step: the result is expected in
    # ascending index order, with one `missing` sample inserted in the gap.
    idx, vals = resample_interval([730, 247], [2., 1.], 240)
    @test length(idx) == 3
    @test idx == [247, 487, 730]
    @test length(vals) == 3
    @test vals[[1, 3]] == [1., 2.]
    @test ismissing(vals[2])

    # Input already spaced exactly by the step is expected back unchanged.
    regular_idx = [1, 3, 5]
    regular_vals = [1., 2., 3.]
    idx, vals = resample_interval(regular_idx, regular_vals, 2)
    @test idx == regular_idx
    @test vals == regular_vals

    # Empty input is expected to yield empty output.
    idx, vals = resample_interval([], [], 100)
    @test idx == []
    @test vals == []
end