Argument order #5

Closed
aplavin opened this issue Apr 16, 2024 · 11 comments

aplavin commented Apr 16, 2024

Nice package, I can see it growing over time to support more functions/distributions!

I wonder if you considered a different argument order, to follow the regular Statistics.jl mean(func, collection) convention:
use mean(log, Exponential(10), ClosedFormExpectation()) instead of the current mean(ClosedFormExpectation(), Exponential(10), log).

This isn't too late to change :)

Collaborator

Nimrais commented Apr 16, 2024

Thank you for your feedback and suggestion! I really appreciate you taking the time to review the package and provide constructive input.

You raise a great point about the argument order. I agree that following the convention of mean(func, collection) from Statistics.jl would make the syntax more intuitive and consistent with the existing ecosystem.

After considering your proposal of mean(log, Exponential(10), ClosedFormExpectation()), I think a slightly different order might make more sense: mean(ClosedFormExpectation(), log, Exponential(10)). The reason is that the expectation form (e.g., ClosedFormExpectation()) can actually change the underlying function (log in our case), so it may make sense for it to go first.
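A minimal sketch of the form-first order, `mean(form, func, dist)`, using only Base and the Statistics stdlib so it runs without Distributions.jl. `ClosedFormExpectation` here is a local placeholder rather than the package's definition, and `ToyExponential` is a hypothetical stand-in for `Exponential(θ)` with θ the scale. The method uses the known result E[log X] = log θ - γ (γ being the Euler-Mascheroni constant). The existing two-argument `mean(f, itr)` from Statistics.jl keeps working alongside it.

```julia
using Statistics
import Statistics: mean

# Local placeholder for the package's expectation-form type (assumption).
struct ClosedFormExpectation end

# Hypothetical stand-in for Distributions.Exponential(θ); θ is the scale (the mean).
struct ToyExponential
    θ::Float64
end

# Form first, then the function, then the distribution.
# For X ~ Exponential(scale θ): E[log X] = log(θ) - γ.
mean(::ClosedFormExpectation, ::typeof(log), d::ToyExponential) =
    log(d.θ) - Base.MathConstants.eulergamma

# The Statistics.jl convention mean(func, collection) is unaffected.
empirical = mean(log, [1.0, 2.0, 4.0])   # (log 1 + log 2 + log 4) / 3 = log 2

closed = mean(ClosedFormExpectation(), log, ToyExponential(10.0))
```

Since the extra method only matches when the first argument is an expectation form, it cannot collide with the Statistics.jl two-argument method.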

Nimrais self-assigned this Apr 16, 2024
Author

aplavin commented Apr 16, 2024

Yeah, totally, it may make sense for the form to go first.

Maybe I don't fully understand the interface, but isn't mean(ClosedFormExpectation(), func, dist) computing the expectation of func(x) when x ~ dist? Then what do you mean by "the expectation form can actually change the underlying function"?

Collaborator

Nimrais commented Apr 16, 2024

Because, for example, we have a ClosedWilliamsProduct structure, and a user can do something like this:

mean(ClosedWilliamsProduct(), log, Exponential(10))

which will give the user the value of $$E_{q}[ \log(x) \nabla_{\theta} \log q(x; \theta) ].$$
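A quick Monte Carlo sanity check of what this expectation evaluates to for Exponential(10). It assumes θ is the scale parameter (as in Distributions.jl, where q(x; θ) = e^{-x/θ}/θ) and that ∇θ is the derivative with respect to that scale. Then ∇θ log q(x; θ) = -1/θ + x/θ², and because log(x) does not depend on θ, the score-function identity gives E_q[log(x) ∇θ log q] = d/dθ E_q[log x] = d/dθ (log θ - γ) = 1/θ = 0.1. Sampling uses `randexp` from the Random stdlib instead of Distributions.jl.

```julia
using Random
using Statistics

# Score of the Exponential density with scale θ: ∂/∂θ log q(x; θ) = -1/θ + x/θ^2.
score(x, θ) = -1 / θ + x / θ^2

θ = 10.0
rng = MersenneTwister(42)      # fixed seed for reproducibility
xs = θ .* randexp(rng, 10^6)   # X = θ·Y with Y ~ Exponential(1) has scale θ

# Monte Carlo estimate of E_q[log(x) ∇θ log q(x; θ)].
estimate = mean(log.(xs) .* score.(xs, θ))

# Closed form from the score-function identity: d/dθ (log θ - γ) = 1/θ.
closed_form = 1 / θ
```

If the package differentiates with respect to the rate 1/θ instead, the same identity gives a different value, so the parameterization matters here.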

Author

aplavin commented Apr 16, 2024

Hmm, I see...

Also, looking at the tests – what does mean(ClosedFormExpectation(), Normal(μ_1, σ_1), log ∘ Normal(μ_2, σ_2)) mean? I don't think Normal(μ_2, σ_2) is a function at all, so how does it make sense to compose it with log?

Collaborator

Nimrais commented Apr 16, 2024

Yes, it's just a convention that should be clarified somewhere. What I mean by it is the composition of log with the Normal pdf, i.e. the log pdf.

Collaborator

Nimrais commented Apr 16, 2024

So yes, indeed, you would not be able to evaluate it as a function, but we can still use the result of $\circ$ for dispatching.
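A minimal sketch of how `log ∘ Normal(μ_2, σ_2)` can serve purely as a dispatch tag, with the argument order taken from the test quoted earlier in the thread. It relies only on Base: `∘` builds a `ComposedFunction` whose `inner` field holds the non-callable object. `ToyNormal` is a hypothetical stand-in for Distributions.jl's `Normal`, `ClosedFormExpectation` is a local placeholder, and the method body uses the standard Gaussian result E_{x~N(μ1,σ1)}[log N(x; μ2, σ2)] = -log(σ2√(2π)) - (σ1² + (μ1 - μ2)²)/(2σ2²).

```julia
using Statistics
import Statistics: mean

# Local placeholder for the package's expectation-form type (assumption).
struct ClosedFormExpectation end

# Hypothetical stand-in for Distributions.Normal(μ, σ); note it is NOT callable.
struct ToyNormal
    μ::Float64
    σ::Float64
end

# `log ∘ ToyNormal(...)` is a ComposedFunction{typeof(log), ToyNormal}:
# it cannot be evaluated, but its type can still be dispatched on.
function mean(::ClosedFormExpectation, p::ToyNormal,
              f::ComposedFunction{typeof(log), ToyNormal})
    q = f.inner   # the "distribution" wrapped by the composition
    return -log(q.σ * sqrt(2π)) - (p.σ^2 + (p.μ - q.μ)^2) / (2 * q.σ^2)
end

p = ToyNormal(0.0, 1.0)
composed = log ∘ ToyNormal(1.0, 2.0)
value = mean(ClosedFormExpectation(), p, composed)

# Evaluating the composition fails, because ToyNormal is not callable.
evaluable = try
    composed(0.0)
    true
catch err
    err isa MethodError || rethrow()
    false
end
```

The trade-off is visible in the last step: the object carries enough type information for `mean`, but a user who tries to call it gets a `MethodError`.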

Author

aplavin commented Apr 16, 2024

Didn't really expect such a twist :)
Why not simply Base.Fix1(logpdf, Normal(μ_2, σ_2)) instead? It is both properly callable, and can be dispatched on.
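A sketch of the `Base.Fix1` alternative, with a hypothetical `ToyNormal` and `toylogpdf` standing in for Distributions.jl's `Normal` and `logpdf`. The same object can be evaluated like any function and matched in a method signature, with the fixed distribution stored in its `x` field.

```julia
# Hypothetical stand-in for Distributions.Normal(μ, σ).
struct ToyNormal
    μ::Float64
    σ::Float64
end

# Hypothetical stand-in for Distributions.logpdf(d, x) on a Normal.
toylogpdf(d::ToyNormal, x::Real) = -log(d.σ * sqrt(2π)) - (x - d.μ)^2 / (2 * d.σ^2)

f = Base.Fix1(toylogpdf, ToyNormal(0.0, 1.0))

# Properly callable: f(x) == toylogpdf(ToyNormal(0.0, 1.0), x).
at_zero = f(0.0)

# Dispatchable: methods can match on the Fix1 type; the distribution is in the `x` field.
describe(g::Base.Fix1{typeof(toylogpdf), ToyNormal}) = "log-density of N($(g.x.μ), $(g.x.σ))"
describe(g) = "some other function"

label = describe(f)
```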

Collaborator

Nimrais commented Apr 16, 2024

I think you are making a good point; however, I want to have a bit of flexibility.

My idea boils down to adding an additional Logpdf structure and implementing it in the following way:

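import Distributions: Distribution, logpdf

# Callable wrapper around a distribution's log-density.
struct Logpdf{D}
    dist::D
end

# Logpdf(d)(args...) forwards all arguments to logpdf(d, args...).
function (f::Logpdf{D})(args...) where {D <: Distribution}
    return logpdf(f.dist, args...)
end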

And the expectation interface will look like this:

mean(ClosedFormExpectation(), Logpdf(LogNormal(1, 10)), Exponential(10))

Author

aplavin commented Apr 16, 2024

It's fine if you need this kind of Logpdf structure in the package for some specific purpose. It seems equivalent to Fix1(logpdf) plus propagation of arbitrary args...; unfortunately, the latter is not supported in Base.

But it would still be nice to support mean(Fix1(logpdf, Normal(...)), dist) – this is the natural Julian representation of such a function. There is even convenient syntax for it:

julia> using Accessors

julia> f = @o logpdf(Normal(0, 1), _)
(::Base.Fix1{typeof(logpdf), Normal{Float64}})

Collaborator

Nimrais commented Apr 16, 2024

Sure, Fix1(logpdf, Normal(...)) will be supported: I will just convert it into Logpdf at the time mean is called.
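A minimal sketch of converting `Fix1(logpdf, dist)` into a `Logpdf` wrapper when `mean` is called, so closed-form methods only need to be written against `Logpdf`. Everything here is hypothetical and self-contained: `ToyNormal`/`toylogpdf` stand in for Distributions.jl's `Normal`/`logpdf`, `ClosedFormExpectation` is a local placeholder, and the only closed form implemented is the Gaussian term E_{x~p}[log q(x)].

```julia
using Statistics
import Statistics: mean

# Local placeholder for the package's expectation-form type (assumption).
struct ClosedFormExpectation end

# Hypothetical stand-ins for Distributions.Normal and Distributions.logpdf.
struct ToyNormal
    μ::Float64
    σ::Float64
end
toylogpdf(d::ToyNormal, x::Real) = -log(d.σ * sqrt(2π)) - (x - d.μ)^2 / (2 * d.σ^2)

# Callable wrapper in the spirit of the proposed Logpdf structure.
struct Logpdf{D}
    dist::D
end
(f::Logpdf)(args...) = toylogpdf(f.dist, args...)

# Closed-form method written once, against Logpdf:
# E_{x ~ p}[log q(x)] for Gaussians p and q.
function mean(::ClosedFormExpectation, f::Logpdf{ToyNormal}, p::ToyNormal)
    q = f.dist
    return -log(q.σ * sqrt(2π)) - (p.σ^2 + (p.μ - q.μ)^2) / (2 * q.σ^2)
end

# Fix1(toylogpdf, q) is accepted by rewrapping it as Logpdf(q) at call time.
mean(form::ClosedFormExpectation, f::Base.Fix1{typeof(toylogpdf), ToyNormal}, d) =
    mean(form, Logpdf(f.x), d)

p = ToyNormal(0.0, 1.0)
q = ToyNormal(1.0, 2.0)
via_wrapper = mean(ClosedFormExpectation(), Logpdf(q), p)
via_fix1 = mean(ClosedFormExpectation(), Base.Fix1(toylogpdf, q), p)
```

The forwarding method is a single line, so every new closed-form rule added for `Logpdf` is automatically available to `Fix1` callers as well.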

Nimrais mentioned this issue Apr 16, 2024
Collaborator

Nimrais commented Apr 16, 2024

@aplavin you can check the state in https://github.com/biaslab/ClosedFormExpectations.jl/tree/clean-interface. I think the PR resolves the issue.

Nimrais closed this as completed Apr 17, 2024