Skip to content

Commit 822a0aa

Browse files
seabbsSamuelBrand1
andauthored
Issue 471: Add TransfromObservationModel (#496)
* add TransfromObservationModel * review unit tests * improve tests * add benchmarks: * add ref to softplus * Update TransformObservationModel.jl * Update TransformObservationModel.jl * Update TransformObservationModel.jl * Update EpiAware/src/EpiObsModels/modifiers/TransformObservationModel.jl Co-authored-by: Samuel Brand <[email protected]> * use negbin so there is sampled parameter * Choose a model with a parameter to sample and document the gotcha * catch a whitespace --------- Co-authored-by: Samuel Brand <[email protected]> Co-authored-by: Samuel Brand <[email protected]>
1 parent ea2766e commit 822a0aa

File tree

6 files changed

+130
-2
lines changed

6 files changed

+130
-2
lines changed

EpiAware/docs/src/developer/contributing.md

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,22 @@ Tests that build example package docs from source and inspect the results (end t
4040
located in `/test/examples`. The main entry points are `test/examples/make.jl` for building and
4141
`test/examples/test.jl` for doing some basic checks on the generated outputs.
4242

43+
## Benchmarking
44+
45+
Benchmarking is orchestrated using `PkgBenchmark.jl` along with a GitHub action that uses `BenchmarkCI.jl` The benchmarks are located in `benchmarks/` and the main entry point is `benchmarks/runbenchmarks.jl`.
46+
47+
The main function in the `benchmark` environment is `make_epiaware_suite` which calls `TuringBenchmarking.make_turing_suite` on a set of `Turing` models generated by `EpiAware` benchmarking their sampling with the following autodiff backends:
48+
49+
- `ForwardDiff.jl`.
50+
- `ReverseDiff.jl`: With `compile = false`.
51+
- `ReverseDiff.jl`: With `compile = true`.
52+
53+
### Benchmarking "gotchas"
54+
55+
#### Models with no parameters
56+
57+
In `EpiAware` we do expose some models thats do not have parameters, for example, Poisson sampling with a transformation on a fixed mean process implemented by `TransformObservationModel(NegativeBinomialError())` has no sampleable parameters (although it does contributed log-likelihood as part of a wider model). This causes `TuringBenchmarking.make_turing_suite` to throw an error as it expects all models to have parameters.
58+
4359
## Pluto usage in showcase documentation
4460

4561
Some of the showcase examples in `EpiAware/docs/src/showcase` use [`Pluto.jl`](https://plutojl.org/) notebooks for the underlying computation. The output of the notebooks is rendered into HTML for inclusion in the documentation in two steps:

EpiAware/src/EpiObsModels/EpiObsModels.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ using ..EpiLatentModels: HierarchicalNormal, broadcast_dayofweek
1111
using ..EpiLatentModels: broadcast_rule, PrefixLatentModel, RepeatEach
1212

1313
using Turing, Distributions, DocStringExtensions, SparseArrays, LinearAlgebra
14-
using LogExpFunctions: xexpy
14+
using LogExpFunctions: xexpy, log1pexp
1515

1616
# Observation error models
1717
export PoissonError, NegativeBinomialError
@@ -21,7 +21,7 @@ export generate_observation_error_priors, observation_error
2121

2222
# Observation model modifiers
2323
export LatentDelay, Ascertainment, PrefixObservationModel, RecordExpectedObs
24-
export Aggregate
24+
export Aggregate, TransformObservationModel
2525

2626
# Observation model manipulators
2727
export StackObservationModels
@@ -36,6 +36,7 @@ include("modifiers/ascertainment/helpers.jl")
3636
include("modifiers/Aggregate.jl")
3737
include("modifiers/PrefixObservationModel.jl")
3838
include("modifiers/RecordExpectedObs.jl")
39+
include("modifiers/TransformObservationModel.jl")
3940
include("StackObservationModels.jl")
4041
include("ObservationErrorModels/methods.jl")
4142
include("ObservationErrorModels/NegativeBinomialError.jl")
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
@doc raw"
2+
The `TransformObservationModel` struct represents an observation model that applies a transformation function to the expected observations before passing them to the underlying observation model.
3+
4+
## Fields
5+
- `model::M`: The underlying observation model.
6+
- `transform::F`: The transformation function applied to the expected observations.
7+
8+
## Constructors
9+
- `TransformObservationModel(model::M, transform::F = x -> log1pexp.(x)) where {M <: AbstractTuringObservationModel, F <: Function}`: Constructs a `TransformObservationModel` instance with the specified observation model and a default transformation function.
10+
- `TransformObservationModel(; model::M, transform::F = x -> log1pexp.(x)) where {M <: AbstractTuringObservationModel, F <: Function}`: Constructs a `TransformObservationModel` instance using named arguments.
11+
- `TransformObservationModel(model::M; transform::F = x -> log1pexp.(x)) where {M <: AbstractTuringObservationModel, F <: Function}`: Constructs a `TransformObservationModel` instance with the specified observation model and a default transformation function.
12+
13+
## Example
14+
15+
```julia
16+
using EpiAware, Distributions, LogExpFunctions
17+
18+
trans_obs = TransformObservationModel(NegativeBinomialError())
19+
gen_obs = generate_observations(trans_obs, missing, fill(10.0, 30))
20+
gen_obs()
21+
```
22+
"
23+
@kwdef struct TransformObservationModel{
24+
M <: AbstractTuringObservationModel, F <: Function} <: AbstractTuringObservationModel
25+
"The underlying observation model."
26+
model::M
27+
"The transformation function. The default is `log1pexp` which is the softplus transformation"
28+
transform::F = x -> log1pexp.(x)
29+
end
30+
31+
function TransformObservationModel(model::M;
32+
transform::F = x -> log1pexp.(x)) where {
33+
M <: AbstractTuringObservationModel, F <: Function}
34+
return TransformObservationModel(model, transform)
35+
end
36+
37+
@doc raw"
38+
Generates observations or accumulates log-likelihood based on the `TransformObservationModel`.
39+
40+
## Arguments
41+
- `obs::TransformObservationModel`: The TransformObservationModel.
42+
- `y_t`: The current state of the observations.
43+
- `Y_t`: The expected observations.
44+
45+
## Returns
46+
- `y_t`: The updated observations.
47+
"
48+
@model function EpiAwareBase.generate_observations(
49+
obs::TransformObservationModel, y_t, Y_t
50+
)
51+
transformed_Y_t = obs.transform(Y_t)
52+
53+
@submodel y_t = generate_observations(obs.model, y_t, transformed_Y_t)
54+
55+
return y_t
56+
end
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
@testitem "Test TransformObservationModel constructor" begin
2+
using Turing, LogExpFunctions
3+
4+
# Test default constructor
5+
trans_obs = TransformObservationModel(NegativeBinomialError())
6+
@test trans_obs.model == NegativeBinomialError()
7+
@test trans_obs.transform([1.0, 2.0, 3.0]) == log1pexp.([1.0, 2.0, 3.0])
8+
9+
# Test constructor with custom transform
10+
custom_transform = x -> exp.(x)
11+
trans_obs_custom = TransformObservationModel(NegativeBinomialError(), custom_transform)
12+
@test trans_obs_custom.model == NegativeBinomialError()
13+
@test trans_obs_custom.transform([1.0, 2.0, 3.0]) == exp.([1.0, 2.0, 3.0])
14+
15+
# Test kwarg constructor
16+
trans_obs_kwarg = TransformObservationModel(
17+
model = PoissonError(), transform = custom_transform)
18+
@test trans_obs_kwarg.model == PoissonError()
19+
@test trans_obs_kwarg.transform == custom_transform
20+
end
21+
22+
@testitem "Test TransformObservationModel generate_observations" begin
23+
using Turing, LogExpFunctions, Distributions
24+
25+
# Test with default log1pexp transform
26+
trans_obs = TransformObservationModel(NegativeBinomialError())
27+
gen_obs = generate_observations(trans_obs, missing, fill(10.0, 1))
28+
samples = sample(gen_obs, Prior(), 1000; progress = false)["y_t[1]"]
29+
30+
# Reverse the transform
31+
reversed_samples = samples .|> exp |> x -> x .- 1 .|> log
32+
# Apply the transform again
33+
recovered_samples = log1pexp.(reversed_samples)
34+
35+
@test all(isapprox.(samples, recovered_samples, rtol = 1e-6))
36+
37+
# Test with custom transform and Poisson distribution
38+
custom_transform = x -> x .^ 2 # Square transform
39+
trans_obs_custom = TransformObservationModel(PoissonError(), custom_transform)
40+
gen_obs_custom = generate_observations(trans_obs_custom, missing, fill(5.0, 1))
41+
samples_custom = sample(gen_obs_custom, Prior(), 1000; progress = false)
42+
# Reverse the transform
43+
reversed_samples_custom = sqrt.(samples_custom["y_t[1]"])
44+
# Apply the transform again
45+
recovered_samples_custom = custom_transform.(reversed_samples_custom)
46+
47+
@test all(isapprox.(samples_custom["y_t[1]"], recovered_samples_custom, rtol = 1e-6))
48+
end

benchmark/bench/EpiObsModels/EpiObsModels.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ include("modifiers/ascertainment/Ascertainment.jl")
88
include("modifiers/ascertainment/helpers.jl")
99
include("modifiers/LatentDelay.jl")
1010
include("modifiers/PrefixObservationModel.jl")
11+
include("modifiers/RecordExpectedObs.jl")
12+
include("modifiers/TransformObservationModel.jl")
1113
include("ObservationErrorModels/methods.jl")
1214
include("ObservationErrorModels/NegativeBinomialError.jl")
1315
include("ObservationErrorModels/PoissonError.jl")
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
let
2+
transform_obs = TransformObservationModel(NegativeBinomialError())
3+
mdl = generate_observations(transform_obs, fill(10, 10), fill(9, 10))
4+
suite["TransformObservationModel"] = make_epiaware_suite(mdl)
5+
end

0 commit comments

Comments
 (0)