
[Bug]: Cannot use ProjectedTo or CVIProjection for constrained Wishart transformation #570

@albertpod

Description

Following the discussion in #565, I'm extending a structural time series (STS) model. In that discussion I used an unconstrained Wishart prior on the full transition precision, which allows correlations between all state components. Here I need to follow the original STS approach and constrain the precision as R*Q*R', where Q ~ Wishart(df, S) and R is a known constraint matrix.

The model structure is:

z[t] ~ ContinuousTransition(zprev, F, QR)  # where QR = R*Q*R'
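The examples below use a 4×2 constraint matrix `R`, so `R*Q*R'` has rank at most 2. It is positive semidefinite but never positive definite, whereas a 4×4 Wishart distribution is supported on positive definite matrices, which matters when choosing the family that q(QR) is projected onto. The minimal check below uses only the standard library; the value of `Q` is an arbitrary positive definite 2×2 matrix chosen for illustration, not a draw from the model.

```julia
using LinearAlgebra

# Constraint matrix from the examples below (4 state components, 2 noise components)
R = [
    0.1 0.0;
    0.0 0.0;
    0.0 1.0;
    0.0 0.0
]

# Arbitrary positive definite 2×2 matrix standing in for a value of Q (illustrative only)
Q = [2.0 0.5; 0.5 1.0]

QR = R * Q * R'

@show size(QR)      # (4, 4)
@show rank(QR)      # 2: rows 2 and 4 of R are zero, so rows/columns 2 and 4 of QR are zero
@show isposdef(Q)   # true
@show isposdef(QR)  # false: QR is only positive semidefinite
```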

Before implementing custom rules, I wanted to experiment using existing functionality (ProjectedTo or CVIProjection), but both approaches fail with different errors.

Expected Behavior

Either ProjectedTo(Wishart) or CVIProjection() should allow me to experiment with the constrained Wishart transformation without having to immediately implement custom inference rules.

Actual Behavior

Both approaches throw errors; the error messages are listed below.

Minimal Reproducible Example or steps to reproduce the issue

Approach 1: Custom node with ProjectedTo

Following the node fusion example, I created a custom TransformedWishart node:

```julia
using RxInfer

# Transition matrix for ContinuousTransition: reshape the 16-vector F into a 4×4 matrix
transition(F) = reshape(F, 4, 4)

struct TransformedWishart{W, R} <: ContinuousMatrixDistribution
    w::W
    r::R
end

# Access the struct fields `w` and `r` (not the type parameters `W` and `R`)
BayesBase.logpdf(dist::TransformedWishart, x) = logpdf(Wishart(size(dist.r, 1) + dist.w.df, dist.r * dist.w.S * transpose(dist.r)), x)
BayesBase.insupport(dist::TransformedWishart, x) = true

@node TransformedWishart Stochastic [out, w, r]

@model function rxsts(y, R)
    Q     ~ Wishart(2, diageye(2))
    # The state is 4-dimensional: F is reshaped to 4×4 and each y[t] has length 4
    zprev ~ MvNormalMeanPrecision(zeros(4), diageye(4))
    QR    ~ TransformedWishart(Q, R)
    F     ~ MvNormalMeanPrecision(vec(diageye(4)), diageye(16))
    for t in eachindex(y)
        z[t] ~ ContinuousTransition(zprev, F, QR)
        # Multivariate likelihood, since z[t] and y[t] are vectors
        y[t] ~ MvNormalMeanPrecision(z[t], diageye(4))
        zprev = z[t]
    end
end

@constraints function rxsts_constraints()
    q(z, zprev, F, QR, Q) = q(z, zprev)q(F)q(QR)q(Q)
    q(QR) :: ProjectedTo(Wishart)
    q(Q)  :: ProjectedTo(Wishart)
end

@meta function rxsts_meta()
    ContinuousTransition() -> CTMeta(transition)
end

@initialization function rxsts_init()
    q(F)     = MvNormalMeanPrecision(vec(diageye(4)), diageye(16))
    q(Q)     = Wishart(2, diageye(2))
    q(QR)    = Wishart(4, diageye(4))
    μ(zprev) = MvNormalMeanPrecision(zeros(4), diageye(4))
end

R = [
    0.1   0.0;
    0.0   0.0;
    0.0   1.0;
    0.0   0.0;
]

results = infer(
    model = rxsts(R = R),
    data = (y = [randn(4) for _ in 1:100],),
    constraints = rxsts_constraints(),
    meta = rxsts_meta(),
    initialization = rxsts_init(),
    returnvars = KeepLast(),
    iterations = 50,
    showprogress = true,
    options = (
        rulefallback = NodeFunctionRuleFallback(),
        limit_stack_depth = 250,
    )
)
```

Approach 2: Deterministic transformation with CVIProjection

Using a deterministic transformation node:

```julia
using RxInfer

# Deterministic map Q ↦ R*Q*R' producing the constrained 4×4 matrix
TransformWishart(Q, R) = R * Q * R'

# Transition matrix for ContinuousTransition: reshape the 16-vector F into a 4×4 matrix
transition(F) = reshape(F, 4, 4)

@model function rxsts(y, R)
    Q     ~ Wishart(2, diageye(2))
    # The state is 4-dimensional: F is reshaped to 4×4 and each y[t] has length 4
    zprev ~ MvNormalMeanPrecision(zeros(4), diageye(4))
    F     ~ MvNormalMeanPrecision(vec(diageye(4)), diageye(16))

    QR := TransformWishart(Q, R)

    for t in eachindex(y)
        z[t] ~ ContinuousTransition(zprev, F, QR)
        # Multivariate likelihood, since z[t] and y[t] are vectors
        y[t] ~ MvNormalMeanPrecision(z[t], diageye(4))
        zprev = z[t]
    end
end

@constraints function rxsts_constraints()
    q(z, zprev, F, QR, Q) = q(z, zprev)q(F)q(QR)q(Q)
end

@meta function rxsts_meta()
    ContinuousTransition() -> CTMeta(transition)
    TransformWishart() -> CVIProjection()
end

@initialization function rxsts_init()
    q(F)     = MvNormalMeanPrecision(vec(diageye(4)), diageye(16))
    q(Q)     = Wishart(2, diageye(2))
    q(QR)    = Wishart(4, diageye(4))
    μ(zprev) = MvNormalMeanPrecision(zeros(4), diageye(4))
end

R = [
    0.1   0.0;
    0.0   0.0;
    0.0   1.0;
    0.0   0.0;
]

results = infer(
    model = rxsts(R = R),
    data = (y = [randn(4) for _ in 1:100],),
    constraints = rxsts_constraints(),
    meta = rxsts_meta(),
    initialization = rxsts_init(),
    returnvars = KeepLast(),
    iterations = 50,
    showprogress = true,
    options = (
        limit_stack_depth = 250,
    )
)
```
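Because `TransformWishart` is a deterministic function of `Q`, a sampling-based approach only needs to push draws of `Q` through it. The sketch below does this outside RxInfer, using only the standard library, and checks the first moment of the pushforward against the closed form E[R Q R'] = R E[Q] R' = df · R S R'. It does not reproduce how `CVIProjection` works internally. The `sample_wishart` helper is hypothetical and uses the sum-of-outer-products construction, which is valid only for integer `df`.

```julia
using LinearAlgebra, Random, Statistics

# Same deterministic map as in the model above
TransformWishart(Q, R) = R * Q * R'

# Hypothetical helper: draw Q ~ Wishart(df, S) for integer df as a sum of
# df outer products of N(0, S) vectors (avoids a Distributions.jl dependency)
function sample_wishart(rng, df::Int, S::AbstractMatrix)
    L = cholesky(Symmetric(S)).L
    Q = zeros(size(S))
    for _ in 1:df
        x = L * randn(rng, size(S, 1))
        Q .+= x * x'
    end
    return Q
end

R = [0.1 0.0; 0.0 0.0; 0.0 1.0; 0.0 0.0]
df, S = 2, Matrix(1.0I, 2, 2)   # matches the prior Q ~ Wishart(2, diageye(2))

rng = MersenneTwister(42)
samples = [TransformWishart(sample_wishart(rng, df, S), R) for _ in 1:20_000]

empirical_mean = mean(samples)      # Monte Carlo estimate of E[R Q R']
analytic_mean  = df * R * S * R'    # closed form df * R S R'

@show round.(empirical_mean; digits = 2)
@show analytic_mean
```

Every pushed-forward sample, like the mean, has rank at most 2, so the samples live on a lower-dimensional set of 4×4 matrices.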

Error Message / Stack Trace

**Approach 1: Custom node with ProjectedTo**

**Error:**

```
ERROR: MethodError: no method matching get_natural_manifold_base(::Type{Wishart}, ::Tuple{}, ::Nothing)
The function `get_natural_manifold_base` exists, but no method is defined for this combination of argument types.
```


**Approach 2: Deterministic transformation with CVIProjection**

**Error:**

```
ERROR: RuleMethodError: no method matching rule for the given arguments
Possible fix, define:
@rule TransformedWishart(:w, Marginalisation) (q_out::Wishart, q_r::PointMass, ) = begin
    return ...
end
```

Julia Version

1.11

RxInfer Version

Latest stable

Environment Information

```
(ppplayground) pkg> status ExponentialFamilyManifolds
Status `~/Julia/blabla/Project.toml`
⌃ [5c9727c4] ExponentialFamilyManifolds v3.0.2
```

Session ID (Optional)

No response

Additional Context

No response

Metadata

Assignees: No one assigned
Labels: bug (Something isn't working)
Status: Backlog
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests