-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Issue 471: Add TransfromObservationModel (#496)
* add TransfromObservationModel * review unit tests * improve tests * add benchmarks: * add ref to softplus * Update TransformObservationModel.jl * Update TransformObservationModel.jl * Update TransformObservationModel.jl * Update EpiAware/src/EpiObsModels/modifiers/TransformObservationModel.jl Co-authored-by: Samuel Brand <[email protected]> * use negbin so there is sampled parameter * Choose a model with a parameter to sample and document the gotcha * catch a whitespace --------- Co-authored-by: Samuel Brand <[email protected]> Co-authored-by: Samuel Brand <[email protected]>
- Loading branch information
1 parent
ea2766e
commit 822a0aa
Showing
6 changed files
with
130 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
56 changes: 56 additions & 0 deletions
56
EpiAware/src/EpiObsModels/modifiers/TransformObservationModel.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
@doc raw" | ||
The `TransformObservationModel` struct represents an observation model that applies a transformation function to the expected observations before passing them to the underlying observation model. | ||
## Fields | ||
- `model::M`: The underlying observation model. | ||
- `transform::F`: The transformation function applied to the expected observations. | ||
## Constructors | ||
- `TransformObservationModel(model::M, transform::F = x -> log1pexp.(x)) where {M <: AbstractTuringObservationModel, F <: Function}`: Constructs a `TransformObservationModel` instance with the specified observation model and a default transformation function. | ||
- `TransformObservationModel(; model::M, transform::F = x -> log1pexp.(x)) where {M <: AbstractTuringObservationModel, F <: Function}`: Constructs a `TransformObservationModel` instance using named arguments. | ||
- `TransformObservationModel(model::M; transform::F = x -> log1pexp.(x)) where {M <: AbstractTuringObservationModel, F <: Function}`: Constructs a `TransformObservationModel` instance with the specified observation model and a default transformation function. | ||
## Example | ||
```julia | ||
using EpiAware, Distributions, LogExpFunctions | ||
trans_obs = TransformObservationModel(NegativeBinomialError()) | ||
gen_obs = generate_observations(trans_obs, missing, fill(10.0, 30)) | ||
gen_obs() | ||
``` | ||
" | ||
@kwdef struct TransformObservationModel{ | ||
M <: AbstractTuringObservationModel, F <: Function} <: AbstractTuringObservationModel | ||
"The underlying observation model." | ||
model::M | ||
"The transformation function. The default is `log1pexp` which is the softplus transformation" | ||
transform::F = x -> log1pexp.(x) | ||
end | ||
|
||
function TransformObservationModel(model::M; | ||
transform::F = x -> log1pexp.(x)) where { | ||
M <: AbstractTuringObservationModel, F <: Function} | ||
return TransformObservationModel(model, transform) | ||
end | ||
|
||
@doc raw" | ||
Generates observations or accumulates log-likelihood based on the `TransformObservationModel`. | ||
## Arguments | ||
- `obs::TransformObservationModel`: The TransformObservationModel. | ||
- `y_t`: The current state of the observations. | ||
- `Y_t`: The expected observations. | ||
## Returns | ||
- `y_t`: The updated observations. | ||
" | ||
@model function EpiAwareBase.generate_observations( | ||
obs::TransformObservationModel, y_t, Y_t | ||
) | ||
transformed_Y_t = obs.transform(Y_t) | ||
|
||
@submodel y_t = generate_observations(obs.model, y_t, transformed_Y_t) | ||
|
||
return y_t | ||
end |
48 changes: 48 additions & 0 deletions
48
EpiAware/test/EpiObsModels/modifiers/TransformObservationModel.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
@testitem "Test TransformObservationModel constructor" begin | ||
using Turing, LogExpFunctions | ||
|
||
# Test default constructor | ||
trans_obs = TransformObservationModel(NegativeBinomialError()) | ||
@test trans_obs.model == NegativeBinomialError() | ||
@test trans_obs.transform([1.0, 2.0, 3.0]) == log1pexp.([1.0, 2.0, 3.0]) | ||
|
||
# Test constructor with custom transform | ||
custom_transform = x -> exp.(x) | ||
trans_obs_custom = TransformObservationModel(NegativeBinomialError(), custom_transform) | ||
@test trans_obs_custom.model == NegativeBinomialError() | ||
@test trans_obs_custom.transform([1.0, 2.0, 3.0]) == exp.([1.0, 2.0, 3.0]) | ||
|
||
# Test kwarg constructor | ||
trans_obs_kwarg = TransformObservationModel( | ||
model = PoissonError(), transform = custom_transform) | ||
@test trans_obs_kwarg.model == PoissonError() | ||
@test trans_obs_kwarg.transform == custom_transform | ||
end | ||
|
||
@testitem "Test TransformObservationModel generate_observations" begin | ||
using Turing, LogExpFunctions, Distributions | ||
|
||
# Test with default log1pexp transform | ||
trans_obs = TransformObservationModel(NegativeBinomialError()) | ||
gen_obs = generate_observations(trans_obs, missing, fill(10.0, 1)) | ||
samples = sample(gen_obs, Prior(), 1000; progress = false)["y_t[1]"] | ||
|
||
# Reverse the transform | ||
reversed_samples = samples .|> exp |> x -> x .- 1 .|> log | ||
# Apply the transform again | ||
recovered_samples = log1pexp.(reversed_samples) | ||
|
||
@test all(isapprox.(samples, recovered_samples, rtol = 1e-6)) | ||
|
||
# Test with custom transform and Poisson distribution | ||
custom_transform = x -> x .^ 2 # Square transform | ||
trans_obs_custom = TransformObservationModel(PoissonError(), custom_transform) | ||
gen_obs_custom = generate_observations(trans_obs_custom, missing, fill(5.0, 1)) | ||
samples_custom = sample(gen_obs_custom, Prior(), 1000; progress = false) | ||
# Reverse the transform | ||
reversed_samples_custom = sqrt.(samples_custom["y_t[1]"]) | ||
# Apply the transform again | ||
recovered_samples_custom = custom_transform.(reversed_samples_custom) | ||
|
||
@test all(isapprox.(samples_custom["y_t[1]"], recovered_samples_custom, rtol = 1e-6)) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
5 changes: 5 additions & 0 deletions
5
benchmark/bench/EpiObsModels/modifiers/TransformObservationModel.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
let | ||
transform_obs = TransformObservationModel(NegativeBinomialError()) | ||
mdl = generate_observations(transform_obs, fill(10, 10), fill(9, 10)) | ||
suite["TransformObservationModel"] = make_epiaware_suite(mdl) | ||
end |