[WIP] Porting kroneckernormal distribution to v4 #4774
Conversation
pymc3/distributions/multivariate.py
Outdated
self.sizes = at.as_tensor_variable([chol.shape[0] for chol in self.chols])
self.N = at.prod(self.sizes)
chols = list(map(cholesky, covs))
chols = [_chol.eval() for _chol in chols]
I don't know how practical it is to use an eval() in a dist. But this is one way to get the outputs of the Cholesky matrices from the Op.
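For context, here is a minimal, self-contained illustration of the pattern under discussion (not code from this PR): calling .eval() on the symbolic output of the cholesky Op yields concrete NumPy arrays, but only because every input in the graph is a constant.

```python
import numpy as np
import aesara.tensor as at
from aesara.tensor.slinalg import cholesky

# Covariance factors wrapped as constants (no free symbolic inputs).
covs = [np.array([[2.0, 0.5], [0.5, 1.0]]), np.eye(3)]
chols = [cholesky(at.as_tensor_variable(cov)) for cov in covs]  # symbolic TensorVariables
chols_np = [chol.eval() for chol in chols]  # concrete ndarrays, since the graphs are constant
```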
Yeah, we almost never want to use eval/aesara.function outside of a place where it's intended to be used (e.g. posterior predictive sampling). Regardless, why would we want/need to use non-symbolic values here?
The exact place this comment tags right now, i.e. dist, is eval()-free. The .eval() is currently required over in rng_fn, since the kronecker() method was returning a TensorVariable (I'm working on replacing that with scipy.linalg.kron, so we can count on eval() being removed from there too). The other place .eval() appears is in logp, where it is used to infer shapes (theoretically that can also be removed with a bit of effort). Overall, with the current implementation (which omits a whole lot of code), I wouldn't worry about eval() as much as the inability to support sigma, which is the bigger issue with this implementation.
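A rough sketch of the direction described above, assuming rng_fn already receives the covariance factors as concrete NumPy arrays; the helper name and signature here are hypothetical and not the PR's actual implementation:

```python
import numpy as np
import scipy.linalg
from functools import reduce

def _kron_rng_fn(rng, mu, covs, size=None):
    # Hypothetical sketch: build the full covariance numerically with
    # scipy.linalg.kron instead of eval()-ing a symbolic kronecker() graph.
    cov = reduce(scipy.linalg.kron, covs)
    return rng.multivariate_normal(mu, cov, size=size)

# Example usage with numeric inputs:
rng = np.random.default_rng(0)
covs = [np.array([[2.0, 0.5], [0.5, 1.0]]), np.eye(3)]
draw = _kron_rng_fn(rng, np.zeros(6), covs)
```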
We simply cannot evaluate things like this unless we place very strong and unnatural restrictions on all of the terms involved. More specifically, the graphs in chols must have no symbolic inputs; otherwise, this will fail. If the tests are passing with this in place, then our testing isn't sufficient, because it's only ever testing Constant/SharedVariable inputs.
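A toy demonstration of the constraint described above (not code from this PR), assuming Aesara's standard behaviour: .eval() succeeds on a constant graph but raises when the graph depends on a free symbolic input, unless a value is supplied explicitly.

```python
import numpy as np
import aesara.tensor as at
from aesara.tensor.slinalg import cholesky

# Constant input: eval() works because the graph can be fully computed.
print(cholesky(at.as_tensor_variable(np.eye(2))).eval())

# Free symbolic input: nothing to compute until a value is provided,
# so eval() without arguments raises a missing-input error.
cov = at.matrix("cov")
try:
    cholesky(cov).eval()
except Exception as err:
    print(type(err).__name__)

# Supplying a value explicitly does work:
print(cholesky(cov).eval({cov: np.eye(2)}))
```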
@ricardoV94 The current failure is the test from #4488, which is something you worked on. The difference between the values is nearly negligible, though you might want to have a look at it again if it needs to be very precise.
Maybe try to increase the tolerance to …
Looks great, I added just a tiny suggestion.
Looks great, although I am not familiar with the distribution, so I will have to trust the tests on this one. I left a couple of minor suggestions.
Checking if the tests pass one last time before merging.
Great work @kc611!
Co-authored-by: Ricardo <[email protected]>
Porting kroneckernormal distribution to v4 (pymc-devs#4774)
This PR refactors the KroneckerNormal distribution to be compatible with v4.