MarginalApprox doesn't allow non-constant covariance parameters or inducing point locations in v4 #5922


Closed
quantheory opened this issue Jun 23, 2022 · 13 comments · Fixed by #6076
Labels
bug GP Gaussian Process

Comments

@quantheory
Contributor

Description of your problem

The following code, which worked in PyMC3 (and still works in v4 if Marginal is used instead of MarginalApprox), is no longer functional:

x = np.linspace(0., 1., 10)
xu = np.linspace(0., 1., 5)
y = np.sin(x)

with pm.Model():
    sigma_gp = pm.HalfNormal('sigma_gp', sigma=1.)
    l = pm.HalfNormal('l', sigma=0.1)
    cov = sigma_gp**2 * pm.gp.cov.Matern32(1, ls=[l])
    gp = pm.gp.MarginalApprox(cov_func=cov, approx='VFE')
    sigma_noise = pm.HalfNormal('sigma_noise', sigma=1.)
    gp.marginal_likelihood('like', X=x[:,None], y=y,
                           noise=sigma_noise, Xu=xu[:,None])
    maxpost = pm.find_MAP()
    print(maxpost)
Complete error traceback
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/tmp/ipykernel_33029/329055196.py in <module>
     11     gp.marginal_likelihood('like', X=x[:,None], y=y,
     12                            noise=sigma_noise, Xu=xu[:,None])
---> 13     maxpost = pm.find_MAP()
     14     print(maxpost)

~/anaconda3/envs/pymc3/lib/python3.9/site-packages/pymc/tuning/starting.py in find_MAP(start, vars, method, return_raw, include_transformed, progressbar, maxeval, model, seed, *args, **kwargs)
    109     )
    110     start = ipfn(seed)
--> 111     model.check_start_vals(start)
    112 
    113     var_names = {var.name for var in vars}

~/anaconda3/envs/pymc3/lib/python3.9/site-packages/pymc/model.py in check_start_vals(self, start)
   1785                 )
   1786 
-> 1787             initial_eval = self.point_logps(point=elem)
   1788 
   1789             if not all(np.isfinite(v) for v in initial_eval.values()):

~/anaconda3/envs/pymc3/lib/python3.9/site-packages/pymc/model.py in point_logps(self, point, round_vals)
   1821 
   1822         factors = self.basic_RVs + self.potentials
-> 1823         factor_logps_fn = [at.sum(factor) for factor in self.logp(factors, sum=False)]
   1824         return {
   1825             factor.name: np.round(np.asarray(factor_logp), round_vals)

~/anaconda3/envs/pymc3/lib/python3.9/site-packages/pymc/model.py in logp(self, vars, jacobian, sum)
    751         rv_logps: List[TensorVariable] = []
    752         if rv_values:
--> 753             rv_logps = joint_logp(list(rv_values.keys()), rv_values, sum=False, jacobian=jacobian)
    754             assert isinstance(rv_logps, list)
    755 

~/anaconda3/envs/pymc3/lib/python3.9/site-packages/pymc/distributions/logprob.py in joint_logp(var, rv_values, jacobian, scaling, transformed, sum, **kwargs)
    255     ]
    256     if unexpected_rv_nodes:
--> 257         raise ValueError(
    258             f"Random variables detected in the logp graph: {unexpected_rv_nodes}.\n"
    259             "This can happen when DensityDist logp or Interval transform functions "

ValueError: Random variables detected in the logp graph: [l, sigma_gp].
This can happen when DensityDist logp or Interval transform functions reference nonlocal variables.

This is probably related, at least to some degree, to the DensityDist changes and to #5024. Note that the following case, where sigma_noise is the only random variable, works fine:

x = np.linspace(0., 1., 10)
xu = np.linspace(0., 1., 5)
y = np.sin(x)

with pm.Model():
    cov = pm.gp.cov.Matern32(1, ls=[1.])
    gp = pm.gp.MarginalApprox(cov_func=cov, approx='VFE')
    sigma_noise = pm.HalfNormal('sigma_noise', sigma=1.)
    gp.marginal_likelihood('like', X=x[:,None], y=y,
                           noise=sigma_noise, Xu=xu[:,None])
    maxpost = pm.find_MAP()
    print(maxpost)

yielding

{'sigma_noise_log__': array(-2.8405184), 'sigma_noise': array(0.05839539)}

I'm guessing this is why the error was not caught earlier; MarginalApprox needs a test with non-constant length scales.
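For reference, a regression test covering this case could look roughly like the sketch below (illustrative only; the test name is made up and this is not code from the PyMC test suite):

import numpy as np
import pymc as pm

def test_marginal_approx_nonconstant_lengthscale():
    x = np.linspace(0., 1., 10)
    xu = np.linspace(0., 1., 5)
    y = np.sin(x)
    with pm.Model() as model:
        ell = pm.HalfNormal('ell', sigma=0.1)  # length scale is a free RV, not a constant
        cov = pm.gp.cov.Matern32(1, ls=[ell])
        gp = pm.gp.MarginalApprox(cov_func=cov, approx='VFE')
        sigma_noise = pm.HalfNormal('sigma_noise', sigma=1.)
        gp.marginal_likelihood('like', X=x[:,None], y=y,
                               noise=sigma_noise, Xu=xu[:,None])
        # evaluating the model logp at the initial point is enough to trigger the bug
        model.point_logps()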

There's a completely different error if the inducing point locations are non-constant. Example code:

x = np.linspace(0., 1., 10)
xu_init = np.linspace(0., 1., 5)
y = np.sin(x)

with pm.Model():
    cov = pm.gp.cov.Matern32(1, ls=[1.])
    gp = pm.gp.MarginalApprox(cov_func=cov, approx='VFE')
    sigma_noise = pm.HalfNormal('sigma_noise', sigma=1.)
    xu = pm.Flat("xu", shape=(5, 1), initval=xu_init[:,None])
    gp.marginal_likelihood('like', X=x[:,None], y=y,
                           noise=sigma_noise, Xu=xu)
    maxpost = pm.find_MAP()
    print(maxpost)
Complete error traceback
---------------------------------------------------------------------------
MissingInputError                         Traceback (most recent call last)
/tmp/ipykernel_33029/1554036267.py in <module>
     10     gp.marginal_likelihood('like', X=x[:,None], y=y,
     11                            noise=sigma_noise, Xu=xu)
---> 12     maxpost = pm.find_MAP()
     13     print(maxpost)

~/anaconda3/envs/pymc3/lib/python3.9/site-packages/pymc/tuning/starting.py in find_MAP(start, vars, method, return_raw, include_transformed, progressbar, maxeval, model, seed, *args, **kwargs)
    109     )
    110     start = ipfn(seed)
--> 111     model.check_start_vals(start)
    112 
    113     var_names = {var.name for var in vars}

~/anaconda3/envs/pymc3/lib/python3.9/site-packages/pymc/model.py in check_start_vals(self, start)
   1785                 )
   1786 
-> 1787             initial_eval = self.point_logps(point=elem)
   1788 
   1789             if not all(np.isfinite(v) for v in initial_eval.values()):

~/anaconda3/envs/pymc3/lib/python3.9/site-packages/pymc/model.py in point_logps(self, point, round_vals)
   1821 
   1822         factors = self.basic_RVs + self.potentials
-> 1823         factor_logps_fn = [at.sum(factor) for factor in self.logp(factors, sum=False)]
   1824         return {
   1825             factor.name: np.round(np.asarray(factor_logp), round_vals)

~/anaconda3/envs/pymc3/lib/python3.9/site-packages/pymc/model.py in logp(self, vars, jacobian, sum)
    751         rv_logps: List[TensorVariable] = []
    752         if rv_values:
--> 753             rv_logps = joint_logp(list(rv_values.keys()), rv_values, sum=False, jacobian=jacobian)
    754             assert isinstance(rv_logps, list)
    755 

~/anaconda3/envs/pymc3/lib/python3.9/site-packages/pymc/distributions/logprob.py in joint_logp(var, rv_values, jacobian, scaling, transformed, sum, **kwargs)
    233 
    234     transform_opt = TransformValuesOpt(transform_map)
--> 235     temp_logp_var_dict = factorized_joint_logprob(
    236         tmp_rvs_to_values,
    237         extra_rewrites=transform_opt,

~/anaconda3/envs/pymc3/lib/python3.9/site-packages/aeppl/joint_logprob.py in factorized_joint_logprob(rv_values, warn_missing_rvs, extra_rewrites, **kwargs)
    145         q_rv_inputs = remapped_vars[len(q_value_vars) :]
    146 
--> 147         q_logprob_vars = _logprob(
    148             node.op,
    149             q_value_vars,

~/anaconda3/envs/pymc3/lib/python3.9/functools.py in wrapper(*args, **kw)
    875                             '1 positional argument')
    876 
--> 877         return dispatch(args[0].__class__)(*args, **kw)
    878 
    879     funcname = getattr(func, '__name__', 'singledispatch function')

~/anaconda3/envs/pymc3/lib/python3.9/site-packages/pymc/distributions/distribution.py in density_dist_logp(op, value_var_list, *dist_params, **kwargs)
    795             _dist_params = dist_params[3:]
    796             value_var = value_var_list[0]
--> 797             return logp(value_var, *_dist_params)
    798 
    799         @_logcdf.register(rv_type)

~/anaconda3/envs/pymc3/lib/python3.9/site-packages/pymc/gp/gp.py in _build_marginal_likelihood_logp(self, y, X, Xu, sigma, jitter)
    685     def _build_marginal_likelihood_logp(self, y, X, Xu, sigma, jitter):
    686         sigma2 = at.square(sigma)
--> 687         Kuu = self.cov_func(Xu)
    688         Kuf = self.cov_func(Xu, X)
    689         Luu = cholesky(stabilize(Kuu, jitter))

~/anaconda3/envs/pymc3/lib/python3.9/site-packages/pymc/gp/cov.py in __call__(self, X, Xs, diag)
     84             return self.diag(X)
     85         else:
---> 86             return self.full(X, Xs)
     87 
     88     def diag(self, X):

~/anaconda3/envs/pymc3/lib/python3.9/site-packages/pymc/gp/cov.py in full(self, X, Xs)
    507 
    508     def full(self, X, Xs=None):
--> 509         X, Xs = self._slice(X, Xs)
    510         r = self.euclidean_dist(X, Xs)
    511         return (1.0 + np.sqrt(3.0) * r) * at.exp(-np.sqrt(3.0) * r)

~/anaconda3/envs/pymc3/lib/python3.9/site-packages/pymc/gp/cov.py in _slice(self, X, Xs)
     95         xdims = X.shape[-1]
     96         if isinstance(xdims, Variable):
---> 97             xdims = xdims.eval()
     98         if self.input_dim != xdims:
     99             warnings.warn(

~/anaconda3/envs/pymc3/lib/python3.9/site-packages/aesara/graph/basic.py in eval(self, inputs_to_values)
    597         inputs = tuple(sorted(inputs_to_values.keys(), key=id))
    598         if inputs not in self._fn_cache:
--> 599             self._fn_cache[inputs] = function(inputs, self)
    600         args = [inputs_to_values[param] for param in inputs]
    601 

~/anaconda3/envs/pymc3/lib/python3.9/site-packages/aesara/compile/function/__init__.py in function(inputs, outputs, mode, updates, givens, no_default_updates, accept_inplace, name, rebuild_strict, allow_input_downcast, profile, on_unused_input)
    315         # note: pfunc will also call orig_function -- orig_function is
    316         #      a choke point that all compilation must pass through
--> 317         fn = pfunc(
    318             params=inputs,
    319             outputs=outputs,

~/anaconda3/envs/pymc3/lib/python3.9/site-packages/aesara/compile/function/pfunc.py in pfunc(params, outputs, mode, updates, givens, no_default_updates, accept_inplace, name, rebuild_strict, allow_input_downcast, profile, on_unused_input, output_keys, fgraph)
    372     )
    373 
--> 374     return orig_function(
    375         inputs,
    376         cloned_outputs,

~/anaconda3/envs/pymc3/lib/python3.9/site-packages/aesara/compile/function/types.py in orig_function(inputs, outputs, mode, accept_inplace, name, profile, on_unused_input, output_keys, fgraph)
   1749     try:
   1750         Maker = getattr(mode, "function_maker", FunctionMaker)
-> 1751         m = Maker(
   1752             inputs,
   1753             outputs,

~/anaconda3/envs/pymc3/lib/python3.9/site-packages/aesara/compile/function/types.py in __init__(self, inputs, outputs, mode, accept_inplace, function_builder, profile, on_unused_input, fgraph, output_keys, name, no_fgraph_prep)
   1507         indices = [[input, None, [input]] for input in inputs]
   1508 
-> 1509         fgraph, found_updates = std_fgraph(
   1510             inputs, outputs, accept_inplace, fgraph=fgraph
   1511         )

~/anaconda3/envs/pymc3/lib/python3.9/site-packages/aesara/compile/function/types.py in std_fgraph(input_specs, output_specs, accept_inplace, fgraph, features, force_clone)
    228             clone |= spec.variable.owner is not None
    229 
--> 230         fgraph = FunctionGraph(
    231             input_vars,
    232             [spec.variable for spec in output_specs] + updates,

~/anaconda3/envs/pymc3/lib/python3.9/site-packages/aesara/graph/fg.py in __init__(self, inputs, outputs, features, clone, update_mapping, **clone_kwds)
    151 
    152         for output in outputs:
--> 153             self.add_output(output, reason="init")
    154 
    155         self.profile = None

~/anaconda3/envs/pymc3/lib/python3.9/site-packages/aesara/graph/fg.py in add_output(self, var, reason, import_missing)
    161         """Add a new variable as an output to this `FunctionGraph`."""
    162         self.outputs.append(var)
--> 163         self.import_var(var, reason=reason, import_missing=import_missing)
    164         self.clients[var].append(("output", len(self.outputs) - 1))
    165 

~/anaconda3/envs/pymc3/lib/python3.9/site-packages/aesara/graph/fg.py in import_var(self, var, reason, import_missing)
    302         # Imports the owners of the variables
    303         if var.owner and var.owner not in self.apply_nodes:
--> 304             self.import_node(var.owner, reason=reason, import_missing=import_missing)
    305         elif (
    306             var.owner is None

~/anaconda3/envs/pymc3/lib/python3.9/site-packages/aesara/graph/fg.py in import_node(self, apply_node, check, reason, import_missing)
    367                                 "for more information on this error."
    368                             )
--> 369                             raise MissingInputError(error_msg, variable=var)
    370 
    371         for node in new_nodes:

MissingInputError: Input 0 (xu) of the graph (indices start from 0), used to compute Shape(xu), was not provided and not given a value. Use the Aesara flag exception_verbosity='high', for more information on this error.
 
Backtrace when that variable is created:

  File "/home/spsantos/anaconda3/envs/pymc3/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3169, in run_cell_async
    has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
  File "/home/spsantos/anaconda3/envs/pymc3/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3361, in run_ast_nodes
    if (await self.run_code(code, result,  async_=asy)):
  File "/home/spsantos/anaconda3/envs/pymc3/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3441, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/tmp/ipykernel_33029/1554036267.py", line 9, in <module>
    xu = pm.Flat("xu", shape=(5, 1), initval=xu_init[:,None])
  File "/home/spsantos/anaconda3/envs/pymc3/lib/python3.9/site-packages/pymc/distributions/continuous.py", line 364, in __new__
    return super().__new__(cls, *args, **kwargs)
  File "/home/spsantos/anaconda3/envs/pymc3/lib/python3.9/site-packages/pymc/distributions/distribution.py", line 271, in __new__
    rv_out = model.register_rv(
  File "/home/spsantos/anaconda3/envs/pymc3/lib/python3.9/site-packages/pymc/model.py", line 1359, in register_rv
    self.create_value_var(rv_var, transform)
  File "/home/spsantos/anaconda3/envs/pymc3/lib/python3.9/site-packages/pymc/model.py", line 1509, in create_value_var
    value_var = rv_var.type()

I don't fully understand this error, since it's not clear to me when or how shape information is propagated through either the Aesara variables or the PyMC wrappers. So I'm not sure whether this error has the same underlying cause as the other one.

I will say that this regression is rather disappointing, since being unable to tune either the hyperparameters or the inducing points makes MarginalApprox much less useful.

Versions and main components

  • PyMC/PyMC3 Version: 4.0.1
  • Aesara/Theano Version: 2.7.3
  • Python Version: 3.9.5
  • Operating system: Linux
  • How did you install PyMC/PyMC3: conda
@michaelosthege
Member

@bwengals @lucianopaz this looks like the refactoring of MarginalApprox to use the new DensityDist was incomplete?
My hunch is that both issues described above would be resolved by refactoring towards Potential?

@michaelosthege michaelosthege added bug GP Gaussian Process labels Jun 23, 2022
@quantheory
Contributor Author

I made a copy of the MarginalApprox class that is identical except for using Potential instead of DensityDist in the marginal_likelihood method; specifically:

        if is_observed:
            return pm.Potential(
                name,
                self._build_marginal_likelihood_logp(
                    y,
                    X,
                    Xu,
                    self.sigma,
                    jitter
                ),
                **kwargs,
            )

With this change I can successfully run an example where both the hyperparameters and inducing points are optimized:

x = np.linspace(0., 1., 10)
xu_init = np.linspace(0., 1., 5)
y = np.sin(x)

with pm.Model():
    sigma_gp = pm.HalfNormal('sigma_gp', sigma=1.)
    l = pm.HalfNormal('l', sigma=0.1)
    cov = sigma_gp**2 * pm.gp.cov.Matern32(1, ls=[l])
    gp = MarginalApprox(cov_func=cov, approx='VFE')
    sigma_noise = pm.HalfNormal('sigma_noise', sigma=1.)
    xu = pm.Flat('xu', shape=(5, 1), initval=xu_init[:,None])
    gp.marginal_likelihood('like', X=x[:,None], y=y,
                           noise=sigma_noise, Xu=xu)
    maxpost = pm.find_MAP()
    print(maxpost)

which outputs

{'sigma_gp_log__': array(0.001806), 'l_log__': array(-2.30458196), 'sigma_noise_log__': array(0.00020157), 'xu': array([[0.00267675],
       [0.24694185],
       [0.49314526],
       [0.75139309],
       [0.98770933]]), 'sigma_gp': array(1.00180763), 'l': array(0.09980051), 'sigma_noise': array(1.00020159)}

Now, I find this to be an implausible result even for this entirely artificial problem. (Somehow the arbitrary initial values for all these variables just happen to be very close to the MAP values?) However, find_MAP has often been producing bad results in v4 (see #5923), so I think that is a separate issue.

@ricardoV94
Member

ricardoV94 commented Jun 27, 2022

The error means the DensityDist logp function is likely referencing other model variables that are not passed as explicit arguments to the logp function and to the DensityDist. They seem to come from self.cov and self.mean, which should be explicit inputs to the DensityDist.

Whether DensityDist is the right object for this depends on whether one needs to sample the variable defined by the DensityDist. Potentials do not have variables being sampled (although one can achieve the same by combining a Potential with Flat variables).
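A minimal sketch of the Potential-plus-Flat pattern mentioned above (the names here are illustrative, not taken from the GP module):

import aesara.tensor as at
import pymc as pm

def custom_logp(value):
    # any Aesara expression; free model variables may appear here, because a
    # Potential term is simply added to the model's joint logp
    return -0.5 * at.sum(at.square(value))

with pm.Model():
    value = pm.Flat('value', shape=3)               # placeholder variable to be sampled
    pm.Potential('value_logp', custom_logp(value))  # attaches the density to it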

pymc/pymc/gp/gp.py

Lines 685 to 787 in f5d3431

    def _build_marginal_likelihood_logp(self, y, X, Xu, sigma, jitter):
        sigma2 = at.square(sigma)
        Kuu = self.cov_func(Xu)
        Kuf = self.cov_func(Xu, X)
        Luu = cholesky(stabilize(Kuu, jitter))
        A = solve_lower(Luu, Kuf)
        Qffd = at.sum(A * A, 0)
        if self.approx == "FITC":
            Kffd = self.cov_func(X, diag=True)
            Lamd = at.clip(Kffd - Qffd, 0.0, np.inf) + sigma2
            trace = 0.0
        elif self.approx == "VFE":
            Lamd = at.ones_like(Qffd) * sigma2
            trace = (1.0 / (2.0 * sigma2)) * (
                at.sum(self.cov_func(X, diag=True)) - at.sum(at.sum(A * A, 0))
            )
        else:  # DTC
            Lamd = at.ones_like(Qffd) * sigma2
            trace = 0.0
        A_l = A / Lamd
        L_B = cholesky(at.eye(Xu.shape[0]) + at.dot(A_l, at.transpose(A)))
        r = y - self.mean_func(X)
        r_l = r / Lamd
        c = solve_lower(L_B, at.dot(A, r_l))
        constant = 0.5 * X.shape[0] * at.log(2.0 * np.pi)
        logdet = 0.5 * at.sum(at.log(Lamd)) + at.sum(at.log(at.diag(L_B)))
        quadratic = 0.5 * (at.dot(r, r_l) - at.dot(c, c))
        return -1.0 * (constant + logdet + quadratic + trace)

    def marginal_likelihood(
        self, name, X, Xu, y, noise=None, is_observed=True, jitter=0.0, **kwargs
    ):
        R"""
        Returns the approximate marginal likelihood distribution, given the input
        locations `X`, inducing point locations `Xu`, data `y`, and white noise
        standard deviations `sigma`.

        Parameters
        ----------
        name: string
            Name of the random variable
        X: array-like
            Function input values. If one-dimensional, must be a column
            vector with shape `(n, 1)`.
        Xu: array-like
            The inducing points. Must have the same number of columns as `X`.
        y: array-like
            Data that is the sum of the function with the GP prior and Gaussian
            noise. Must have shape `(n, )`.
        noise: scalar, Variable
            Standard deviation of the Gaussian noise.
        is_observed: bool
            Whether to set `y` as an `observed` variable in the `model`.
            Default is `True`.
        jitter: scalar
            A small correction added to the diagonal of positive semi-definite
            covariance matrices to ensure numerical stability. Default value is 0.0.
        **kwargs
            Extra keyword arguments that are passed to `MvNormal` distribution
            constructor.
        """
        self.X = X
        self.Xu = Xu
        self.y = y
        if noise is None:
            raise ValueError("noise argument must be specified")
        else:
            self.sigma = noise
        if is_observed:
            return pm.DensityDist(
                name,
                X,
                Xu,
                self.sigma,
                jitter,
                logp=self._build_marginal_likelihood_logp,
                observed=y,
                ndims_params=[2, 2, 0],
                size=X.shape[0],
                **kwargs,
            )
        else:
            warnings.warn(
                "The 'is_observed' argument has been deprecated. If the GP is "
                "unobserved use gp.Latent instead.",
                FutureWarning,
            )
            return pm.DensityDist(
                name,
                X,
                Xu,
                self.sigma,
                jitter,
                logp=self._build_marginal_likelihood_logp,
                observed=y,
                ndims_params=[2, 2, 0],
                # ndim_supp=1,
                size=X.shape[0],
                **kwargs,
            )

@bwengals
Contributor

bwengals commented Jul 29, 2022

Apologies for missing this!

Whether DensityDist is the right object for this depends on whether one needs to sample the variable defined by the DensityDist.

By that measure, I think the answer is yes. This is really just a sort of weird approximated logp of a multivariate normal. I should have provided a random arg to DensityDist, but you could use conditional instead. Is this sort of thing something that's better handled in aeppl? I've been meaning to ask you about it @ricardoV94

Is the fix just needing self.cov and self.mean passed in as inputs -- basically whatever needs to happen to decorate _build_marginal_likelihood_logp as a staticmethod?

@ricardoV94
Member

Is this sort of thing something that's better in aeppl? Have been meaning to ask you about it @ricardoV94

I don't have a good enough mental model of what this does to answer, but it seems like... no? Aeppl is about converting graphs of RandomVariables into their respective closed-form logp graphs.

Is the fix just needing self.cov and self.mean passed in as inputs -- basically whatever needs to happen to decorate _build_marginal_likelihood_logp as a staticmethod?

Yes. All non-local (non-constant) TensorVariables used in the logp function just need to be passed as explicit inputs. Since you are working with a multivariate DensityDist, you also need to be careful to pass the right ndim_supp and ndims_params.

This test shows how DensityDist is used to reimplement an MvNormal:

def test_density_dist_multivariate_logp(size):
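For illustration, a sketch roughly following that pattern (not the referenced test verbatim, and not the eventual fix): an MvNormal reimplemented with DensityDist, where the RV that the logp depends on is passed as an explicit parameter.

import numpy as np
import aesara.tensor as at
import pymc as pm

def mvnormal_logp(value, mu, cov):
    # delegate to the closed-form MvNormal logp
    return pm.logp(pm.MvNormal.dist(mu=mu, cov=cov), value)

with pm.Model():
    mu = pm.Normal('mu', 0., 1., size=3)  # an RV used inside the logp...
    cov = at.eye(3)
    pm.DensityDist(
        'obs',
        mu,                               # ...so it is passed as an explicit input
        cov,
        logp=mvnormal_logp,
        observed=np.zeros(3),
        ndims_params=[1, 2],
        ndim_supp=1,
    )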

@bwengals
Contributor

Thanks. RE aeppl, I see, then no.

That should be a pretty easy fix then. Thanks for the test, that'll help too

@bwengals
Contributor

Actually, this might be a bit trickier, and maybe is more aeppl-ish. self.mean_func and self.cov_func are functions, and DensityDist immediately tries to convert them to tensor variables:

NotImplementedError: Cannot convert <pymc.gp.mean.Constant object at 0x1a22da400> to a tensor variable.

Is there a way around this? It probably wouldn't be a good idea to try and unpack the RVs that might be inside mean_func or cov_func.

  • So gp.MarginalApprox.marginal_likelihood should return a variation of an MvNormal random variable whose logp is calculated in a particular, non-standard way (it's the extra trace term) -- it's not a regular MvNormal whose mu vector and cov matrix inputs can be pre-calculated and passed in.

  • On the other hand, gp.MarginalApprox.conditional should return an actual MvNormal random variable. The logp here comes from a MvNormal with a particular calculation of mu vector and cov matrix (see here).

  • The random method of the marginal_likelihood random variable is the same as the random method for conditional.

Since aeppl gives logp implementations for different RVs, I think it might make sense here?

@ricardoV94
Member

It probably wouldn't be a good idea to try and unpack the RVs that might be inside mean_func or cov_func.

Why not?

@bwengals
Contributor

bwengals commented Aug 1, 2022

A mean_func is a thin wrapper around anything really. A cov_func could be sums or products of other cov_funcs, any of which could hold references to any number of rvs. I think it'd be complicated to unpack and then reconstruct them.
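For example (a sketch illustrating the point above): a composite covariance can hold references to several free RVs without exposing any of them in its call signature.

import pymc as pm

with pm.Model():
    eta = pm.HalfNormal('eta', sigma=1.)
    ell1 = pm.HalfNormal('ell1', sigma=1.)
    ell2 = pm.HalfNormal('ell2', sigma=1.)
    cov = eta**2 * pm.gp.cov.Matern32(1, ls=[ell1]) + pm.gp.cov.ExpQuad(1, ls=[ell2])
    # cov(X) depends on eta, ell1 and ell2, but cov itself is a plain Python object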

Just to be sure, I should have been more specific when I said "the sort of thing that's better in aeppl". What I'm picturing is, instead of DensityDist, defining something like a "GP" random variable that registers different logp and random methods depending on the approximation or other things. I wasn't suggesting to close this issue by actually porting it into aeppl.

I'm not super familiar with aeppl, so apologies if I'm not quite getting it.

@ricardoV94
Member

ricardoV94 commented Aug 1, 2022

The only issue I see here is that any RV used in the logp needs to be an input to the DensityDist or whatever Distribution class is used instead. There is no way around that, because Aeppl needs to know all the RV dependencies in advance before the logp methods are called.

DensityDist makes sense if the logp function itself changes across different GPs, otherwise you could create a single GPRV (or GPMvNormalRV or whatever you want to call it) which has a single logp function. Regardless, the constraint that any RV used in the logp must be an explicit input to the Densitydist/new RV remains the same.

Is the set of RVs used in the logp impossible to know in advance / when you are calling the DensityDist?

@ricardoV94
Member

Sounds like the problems arise from the GP module creating lazy Cov/Mean Python objects instead of Aesara graphs, which would demand explicit inputs from the get-go.

@shoerman

shoerman commented Aug 23, 2022

I am getting the same error with the MarginalSparse process. Am I correct in concluding that this makes it impossible to do multivariate SGPR with PyMC v4?

@bwengals
Contributor

@shoerman At the moment, yes, but it will be fixed soon; see #6076.
