Description
Following the discussion in #565, I'm extending a structural time series (STS) model. In that discussion I placed an unconstrained Wishart prior on the full transition precision, which allows correlations between all state components. I now need to follow the original STS approach and constrain the precision as R*Q*R', where Q ~ Wishart(df, S) and R is a known constraint matrix.
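The snippet below is a minimal plain-Julia sketch of the constraint, without RxInfer. It uses the `R` from the MWE below and a hypothetical positive-definite `Q` standing in for one Wishart draw. It shows that `R*Q*R'` is a symmetric positive semi-definite 4×4 matrix whose rank is at most the size of `Q` (here 2). It also shows that, because E[Q] = df·S for Q ~ Wishart(df, S), linearity gives E[R*Q*R'] = df·R*S*R'.

```julia
using LinearAlgebra

# Known 4×2 constraint matrix from the MWE, so R*Q*R' is 4×4
R = [
    0.1 0.0;
    0.0 0.0;
    0.0 1.0;
    0.0 0.0;
]

# Hypothetical 2×2 positive-definite matrix standing in for one draw of Q ~ Wishart(df, S)
Q = [2.0 0.5;
     0.5 1.0]

# Constrained 4×4 matrix that is passed to ContinuousTransition
QR = R * Q * R'

# Symmetric PSD, but rows/columns 2 and 4 are zero, so the rank is only 2
@show rank(QR)  # prints rank(QR) = 2
@show minimum(eigvals(Symmetric(QR)))

# E[Q] = df * S, so by linearity E[R*Q*R'] = df * R*S*R'
df, S = 2, Matrix(1.0I, 2, 2)
mean_QR = df * R * S * R'
```

With `df = 2` and `S = I`, `mean_QR` is diagonal with entries 0.02, 0, 2, 0, i.e. only the components selected by `R` receive any noise precision.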
The model structure is:

```julia
z[t] ~ ContinuousTransition(zprev, F, QR) # where QR = R*Q*R'
```

Before implementing custom rules, I wanted to experiment using existing functionality (`ProjectedTo` or `CVIProjection`), but both approaches fail with different errors.
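For reference, here is a plain-Julia sketch of the shapes in `z[t] ~ ContinuousTransition(zprev, F, QR)`, using the same `transition(F) = reshape(F, 4, 4)` that the examples pass to `CTMeta`. It assumes `ContinuousTransition` uses `transition(F) * zprev` as the mean of `z[t]` and `QR` as its precision; `F_mean` and `z_mean` are hypothetical names for illustration only.

```julia
using LinearAlgebra

# Same transform as passed to CTMeta in the MWE: a 16-vector becomes a 4×4 matrix
transition(F) = reshape(F, 4, 4)

# Plain-Julia stand-in for the prior mean vec(diageye(4)) of F
F_mean = vec(Matrix(1.0I, 4, 4))

# Julia reshapes column-major, so vec followed by reshape recovers the identity
A = transition(F_mean)

# A 4×4 transition matrix (and a 4×4 precision R*Q*R') implies a 4-dimensional state
zprev = randn(4)
z_mean = A * zprev  # assumed conditional mean of z[t] given zprev
```

Because the prior mean of `F` reshapes to the identity, the initial transition simply carries `zprev` forward.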
Expected Behavior
Either ProjectedTo(Wishart) or CVIProjection() should allow me to experiment with the constrained Wishart transformation without having to immediately implement custom inference rules.
Actual Behavior
Both approaches throw errors (full messages below).
Minimal Reproducible Example or steps to reproduce the issue
Approach 1: Custom node with ProjectedTo
Following the node fusion example, I created a custom TransformedWishart node:
```julia
using RxInfer

# Transition-matrix transform used by CTMeta (same definition as in Approach 2)
transition(F) = reshape(F, 4, 4)

struct TransformedWishart{W, R} <: ContinuousMatrixDistribution
    w::W
    r::R
end

# Field names are lowercase `w` and `r`, matching the struct definition
BayesBase.logpdf(dist::TransformedWishart, x) = logpdf(Wishart(size(dist.r, 1) + dist.w.df, dist.r * dist.w.S * transpose(dist.r)), x)
BayesBase.insupport(dist::TransformedWishart, x) = true

@node TransformedWishart Stochastic [out, w, r]

@model function rxsts(y, R)
    Q ~ Wishart(2, diageye(2))
    # 4-dimensional state, matching the 4×4 transition and the 4×4 precision R*Q*R'
    zprev ~ MvNormalMeanPrecision(zeros(4), diageye(4))
    QR ~ TransformedWishart(Q, R)
    F ~ MvNormalMeanPrecision(vec(diageye(4)), diageye(16))
    for t in eachindex(y)
        z[t] ~ ContinuousTransition(zprev, F, QR)
        # Multivariate observation of the 4-dimensional state
        y[t] ~ MvNormalMeanPrecision(z[t], diageye(4))
        zprev = z[t]
    end
end

@constraints function rxsts_constraints()
    q(z, zprev, F, QR, Q) = q(z, zprev)q(F)q(QR)q(Q)
    q(QR) :: ProjectedTo(Wishart)
    q(Q) :: ProjectedTo(Wishart)
end

@meta function rxsts_meta()
    ContinuousTransition() -> CTMeta(transition)
end

@initialization function rxsts_init()
    q(F) = MvNormalMeanPrecision(vec(diageye(4)), diageye(16))
    q(Q) = Wishart(2, diageye(2))
    q(QR) = Wishart(4, diageye(4))
    μ(zprev) = MvNormalMeanPrecision(zeros(4), diageye(4))
end

# Known 4×2 constraint matrix, so R*Q*R' is 4×4
R = [
    0.1 0.0;
    0.0 0.0;
    0.0 1.0;
    0.0 0.0;
]

results = infer(
    model = rxsts(R = R),
    data = (y = [randn(4) for _ in 1:100],),
    constraints = rxsts_constraints(),
    meta = rxsts_meta(),
    initialization = rxsts_init(),
    returnvars = KeepLast(),
    iterations = 50,
    showprogress = true,
    options = (
        rulefallback = NodeFunctionRuleFallback(),
        limit_stack_depth = 250,
    )
)
```

Approach 2: Deterministic transformation with CVIProjection
Using a deterministic transformation node:
```julia
using RxInfer

# Deterministic constraint QR = R*Q*R'
TransformWishart(Q, R) = R*Q*R'
# Transition-matrix transform used by CTMeta
transition(F) = reshape(F, 4, 4)

@model function rxsts(y, R)
    Q ~ Wishart(2, diageye(2))
    # 4-dimensional state, matching the 4×4 transition and the 4×4 precision R*Q*R'
    zprev ~ MvNormalMeanPrecision(zeros(4), diageye(4))
    F ~ MvNormalMeanPrecision(vec(diageye(4)), diageye(16))
    QR := TransformWishart(Q, R)
    for t in eachindex(y)
        z[t] ~ ContinuousTransition(zprev, F, QR)
        # Multivariate observation of the 4-dimensional state
        y[t] ~ MvNormalMeanPrecision(z[t], diageye(4))
        zprev = z[t]
    end
end

@constraints function rxsts_constraints()
    q(z, zprev, F, QR, Q) = q(z, zprev)q(F)q(QR)q(Q)
end

@meta function rxsts_meta()
    ContinuousTransition() -> CTMeta(transition)
    TransformWishart() -> CVIProjection()
end

@initialization function rxsts_init()
    q(F) = MvNormalMeanPrecision(vec(diageye(4)), diageye(16))
    q(Q) = Wishart(2, diageye(2))
    q(QR) = Wishart(4, diageye(4))
    μ(zprev) = MvNormalMeanPrecision(zeros(4), diageye(4))
end

R = [
    0.1 0.0;
    0.0 0.0;
    0.0 1.0;
    0.0 0.0;
]

results = infer(
    model = rxsts(R = R),
    data = (y = [randn(4) for _ in 1:100],),
    constraints = rxsts_constraints(),
    meta = rxsts_meta(),
    initialization = rxsts_init(),
    returnvars = KeepLast(),
    iterations = 50,
    showprogress = true,
    options = (
        limit_stack_depth = 250,
    )
)
```

Error Message / Stack Trace
**Approach 1: Custom node with ProjectedTo**
**Error:**

```
ERROR: MethodError: no method matching get_natural_manifold_base(::Type{Wishart}, ::Tuple{}, ::Nothing)
The function `get_natural_manifold_base` exists, but no method is defined for this combination of argument types.
```
**Approach 2: Deterministic transformation with CVIProjection**
**Error:**

```
ERROR: RuleMethodError: no method matching rule for the given arguments
Possible fix, define:
@rule TransformedWishart(:w, Marginalisation) (q_out::Wishart, q_r::PointMass, ) = begin
    return ...
end
```

Julia Version
1.11
RxInfer Version
Latest stable
Environment Information
```
(ppplayground) pkg> status ExponentialFamilyManifolds
Status `~/Julia/blabla/Project.toml`
⌃ [5c9727c4] ExponentialFamilyManifolds v3.0.2
```

Session ID (Optional)
No response
Additional Context
No response