Issue 471: Add TransformObservationModel (#496)
* add TransformObservationModel

* review unit tests

* improve tests

* add benchmarks

* add ref to softplus

* Update TransformObservationModel.jl

* Update TransformObservationModel.jl

* Update TransformObservationModel.jl

* Update EpiAware/src/EpiObsModels/modifiers/TransformObservationModel.jl

Co-authored-by: Samuel Brand <[email protected]>

* use negbin so there is sampled parameter

* Choose a model with a parameter to sample and document the gotcha

* catch a whitespace

---------

Co-authored-by: Samuel Brand <[email protected]>
Co-authored-by: Samuel Brand <[email protected]>
3 people authored Oct 24, 2024
1 parent ea2766e commit 822a0aa
Showing 6 changed files with 130 additions and 2 deletions.
16 changes: 16 additions & 0 deletions EpiAware/docs/src/developer/contributing.md
@@ -40,6 +40,22 @@ Tests that build example package docs from source and inspect the results (end to end tests) are
located in `/test/examples`. The main entry points are `test/examples/make.jl` for building and
`test/examples/test.jl` for doing some basic checks on the generated outputs.

## Benchmarking

Benchmarking is orchestrated using `PkgBenchmark.jl` along with a GitHub Action that uses `BenchmarkCI.jl`. The benchmarks are located in `benchmark/` and the main entry point is `benchmark/runbenchmarks.jl`.

The main function in the `benchmark` environment is `make_epiaware_suite`, which calls `TuringBenchmarking.make_turing_suite` on a set of `Turing` models generated by `EpiAware`, benchmarking their sampling with the following autodiff backends (a sketch follows the list):

- `ForwardDiff.jl`.
- `ReverseDiff.jl`: With `compile = false`.
- `ReverseDiff.jl`: With `compile = true`.
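
The sketch below is illustrative only: it assumes `TuringBenchmarking.jl`'s `make_turing_suite` keyword interface and the `ADTypes` backend constructors, and the helper name `make_epiaware_suite_sketch` is made up for this example rather than being the package's actual implementation.

```julia
# A minimal sketch, assuming TuringBenchmarking.jl and ADTypes are available.
using TuringBenchmarking, ADTypes

function make_epiaware_suite_sketch(model)
    adbackends = [
        ADTypes.AutoForwardDiff(),                   # ForwardDiff.jl
        ADTypes.AutoReverseDiff(; compile = false),  # ReverseDiff.jl, uncompiled tape
        ADTypes.AutoReverseDiff(; compile = true)    # ReverseDiff.jl, compiled tape
    ]
    return TuringBenchmarking.make_turing_suite(model; adbackends)
end
```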

### Benchmarking "gotchas"

#### Models with no parameters

In `EpiAware` we do expose some models that do not have parameters. For example, Poisson sampling with a transformation on a fixed mean process, as implemented by `TransformObservationModel(PoissonError())`, has no sampleable parameters (although it does contribute log-likelihood as part of a wider model). This causes `TuringBenchmarking.make_turing_suite` to throw an error, as it expects all models to have parameters.
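
A minimal sketch of the gotcha, using only constructors exported in this pull request (the data values below are illustrative):

```julia
using EpiAware

y_t = fill(10, 10)   # observed counts
Y_t = fill(9.0, 10)  # expected values, passed through the softplus transform

# No sampleable parameters: Poisson error with fully observed data.
no_params_mdl = generate_observations(TransformObservationModel(PoissonError()), y_t, Y_t)

# Has a sampleable parameter (the negative binomial cluster factor), so it benchmarks fine.
has_params_mdl = generate_observations(TransformObservationModel(NegativeBinomialError()), y_t, Y_t)
```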

## Pluto usage in showcase documentation

Some of the showcase examples in `EpiAware/docs/src/showcase` use [`Pluto.jl`](https://plutojl.org/) notebooks for the underlying computation. The output of the notebooks is rendered into HTML for inclusion in the documentation in two steps:
5 changes: 3 additions & 2 deletions EpiAware/src/EpiObsModels/EpiObsModels.jl
@@ -11,7 +11,7 @@ using ..EpiLatentModels: HierarchicalNormal, broadcast_dayofweek
using ..EpiLatentModels: broadcast_rule, PrefixLatentModel, RepeatEach

using Turing, Distributions, DocStringExtensions, SparseArrays, LinearAlgebra
using LogExpFunctions: xexpy
using LogExpFunctions: xexpy, log1pexp

# Observation error models
export PoissonError, NegativeBinomialError
@@ -21,7 +21,7 @@ export generate_observation_error_priors, observation_error

# Observation model modifiers
export LatentDelay, Ascertainment, PrefixObservationModel, RecordExpectedObs
export Aggregate
export Aggregate, TransformObservationModel

# Observation model manipulators
export StackObservationModels
@@ -36,6 +36,7 @@ include("modifiers/ascertainment/helpers.jl")
include("modifiers/Aggregate.jl")
include("modifiers/PrefixObservationModel.jl")
include("modifiers/RecordExpectedObs.jl")
include("modifiers/TransformObservationModel.jl")
include("StackObservationModels.jl")
include("ObservationErrorModels/methods.jl")
include("ObservationErrorModels/NegativeBinomialError.jl")
56 changes: 56 additions & 0 deletions EpiAware/src/EpiObsModels/modifiers/TransformObservationModel.jl
@@ -0,0 +1,56 @@
@doc raw"
The `TransformObservationModel` struct represents an observation model that applies a transformation function to the expected observations before passing them to the underlying observation model.
## Fields
- `model::M`: The underlying observation model.
- `transform::F`: The transformation function applied to the expected observations.
## Constructors
- `TransformObservationModel(model::M, transform::F = x -> log1pexp.(x)) where {M <: AbstractTuringObservationModel, F <: Function}`: Constructs a `TransformObservationModel` from positional arguments, with the softplus transformation as the default.
- `TransformObservationModel(; model::M, transform::F = x -> log1pexp.(x)) where {M <: AbstractTuringObservationModel, F <: Function}`: Constructs a `TransformObservationModel` using keyword arguments.
- `TransformObservationModel(model::M; transform::F = x -> log1pexp.(x)) where {M <: AbstractTuringObservationModel, F <: Function}`: Constructs a `TransformObservationModel` with the observation model as a positional argument and the transform as a keyword argument.
## Example
```julia
using EpiAware, Distributions, LogExpFunctions
trans_obs = TransformObservationModel(NegativeBinomialError())
gen_obs = generate_observations(trans_obs, missing, fill(10.0, 30))
gen_obs()
```
"
@kwdef struct TransformObservationModel{
M <: AbstractTuringObservationModel, F <: Function} <: AbstractTuringObservationModel
"The underlying observation model."
model::M
"The transformation function. The default is `log1pexp` which is the softplus transformation"
transform::F = x -> log1pexp.(x)
end

function TransformObservationModel(model::M;
transform::F = x -> log1pexp.(x)) where {
M <: AbstractTuringObservationModel, F <: Function}
return TransformObservationModel(model, transform)
end

@doc raw"
Generates observations or accumulates log-likelihood based on the `TransformObservationModel`.
## Arguments
- `obs::TransformObservationModel`: The TransformObservationModel.
- `y_t`: The current state of the observations.
- `Y_t`: The expected observations.
## Returns
- `y_t`: The updated observations.
"
@model function EpiAwareBase.generate_observations(
obs::TransformObservationModel, y_t, Y_t
)
transformed_Y_t = obs.transform(Y_t)

@submodel y_t = generate_observations(obs.model, y_t, transformed_Y_t)

return y_t
end
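
A hedged usage sketch (not part of the diff): the observed counts, sampler choice and draw count below are illustrative assumptions, but the call pattern follows the docstring example above.

```julia
# A minimal sketch: conditioning on observed counts accumulates log-likelihood
# and lets the negative binomial cluster factor be sampled.
using EpiAware, Turing

obs_model = TransformObservationModel(NegativeBinomialError())  # softplus link by default
Y_t = fill(10.0, 30)   # expected values before the transform
y_t = fill(10, 30)     # illustrative observed counts
mdl = generate_observations(obs_model, y_t, Y_t)
chn = sample(mdl, NUTS(), 500; progress = false)
```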
48 changes: 48 additions & 0 deletions EpiAware/test/EpiObsModels/modifiers/TransformObservationModel.jl
@@ -0,0 +1,48 @@
@testitem "Test TransformObservationModel constructor" begin
using Turing, LogExpFunctions

# Test default constructor
trans_obs = TransformObservationModel(NegativeBinomialError())
@test trans_obs.model == NegativeBinomialError()
@test trans_obs.transform([1.0, 2.0, 3.0]) == log1pexp.([1.0, 2.0, 3.0])

# Test constructor with custom transform
custom_transform = x -> exp.(x)
trans_obs_custom = TransformObservationModel(NegativeBinomialError(), custom_transform)
@test trans_obs_custom.model == NegativeBinomialError()
@test trans_obs_custom.transform([1.0, 2.0, 3.0]) == exp.([1.0, 2.0, 3.0])

# Test kwarg constructor
trans_obs_kwarg = TransformObservationModel(
model = PoissonError(), transform = custom_transform)
@test trans_obs_kwarg.model == PoissonError()
@test trans_obs_kwarg.transform == custom_transform
end

@testitem "Test TransformObservationModel generate_observations" begin
using Turing, LogExpFunctions, Distributions

# Test with default log1pexp transform
trans_obs = TransformObservationModel(NegativeBinomialError())
gen_obs = generate_observations(trans_obs, missing, fill(10.0, 1))
samples = sample(gen_obs, Prior(), 1000; progress = false)["y_t[1]"]

# Reverse the transform
reversed_samples = log.(exp.(samples) .- 1)
# Apply the transform again
recovered_samples = log1pexp.(reversed_samples)

@test all(isapprox.(samples, recovered_samples, rtol = 1e-6))

# Test with custom transform and Poisson distribution
custom_transform = x -> x .^ 2 # Square transform
trans_obs_custom = TransformObservationModel(PoissonError(), custom_transform)
gen_obs_custom = generate_observations(trans_obs_custom, missing, fill(5.0, 1))
samples_custom = sample(gen_obs_custom, Prior(), 1000; progress = false)
# Reverse the transform
reversed_samples_custom = sqrt.(samples_custom["y_t[1]"])
# Apply the transform again
recovered_samples_custom = custom_transform.(reversed_samples_custom)

@test all(isapprox.(samples_custom["y_t[1]"], recovered_samples_custom, rtol = 1e-6))
end
2 changes: 2 additions & 0 deletions benchmark/bench/EpiObsModels/EpiObsModels.jl
@@ -8,6 +8,8 @@ include("modifiers/ascertainment/Ascertainment.jl")
include("modifiers/ascertainment/helpers.jl")
include("modifiers/LatentDelay.jl")
include("modifiers/PrefixObservationModel.jl")
include("modifiers/RecordExpectedObs.jl")
include("modifiers/TransformObservationModel.jl")
include("ObservationErrorModels/methods.jl")
include("ObservationErrorModels/NegativeBinomialError.jl")
include("ObservationErrorModels/PoissonError.jl")
5 changes: 5 additions & 0 deletions benchmark/bench/EpiObsModels/modifiers/TransformObservationModel.jl
@@ -0,0 +1,5 @@
let
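# Use NegativeBinomialError (rather than a parameter-free error model) with observed
# data so the benchmarked model has a sampleable parameter; see the
# "Models with no parameters" gotcha in the contributing docs.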
transform_obs = TransformObservationModel(NegativeBinomialError())
mdl = generate_observations(transform_obs, fill(10, 10), fill(9, 10))
suite["TransformObservationModel"] = make_epiaware_suite(mdl)
end
