-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Add model_to_minibatch
transformation to convert all pm.Data
to pm.Minibatch
#7785
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
You can use the lower level utility: pymc/pymc/variational/minibatch_rv.py Line 53 in ef26ae8
Then make that a vanilla observed RV |
Ah you already did that, so your question is how to get total size? Grab the batch shape of the variable and constant fold it without raising if it can't be fully folded |
My real issue was not understanding what needs to be the key and value in the replacements, between:
|
the best is usual to replace the whole fgraph |
I don't really understand what that answer means |
dprint the fgraph and it will perhaps be more obvious what I am mumbling |
The problem i was running into was that I ended up with two |
Because Minibatch assumes the data variables have the same length, it might make sense to take a variables argument. Or have some way to group data variables of the same size (same dim name maybe?)
…On Thu, 15 May 2025, 15:35 Ricardo Vieira, ***@***.***> wrote:
*ricardoV94* left a comment (pymc-devs/pymc#7785)
<#7785 (comment)>
dprint the fgraph and it will perhaps be more obvious what I am mumbling
—
Reply to this email directly, view it on GitHub
<#7785 (comment)>, or
unsubscribe
<https://github.com/notifications/unsubscribe-auth/AAACCUMC5VCN6VAAJKNHEMT26SJPZAVCNFSM6AAAAAB5F7LYYKVHI2DSMVQWIX3LMV43OSLTON2WKQ3PNVWWK3TUHMZDQOBTHAZTINZXG4>
.
You are receiving this because you were mentioned.Message ID:
***@***.***>
|
Description
A pain point for me when testing different algorithms (e.g. MCMC vs VI) is that I don't want to write a 2nd version of the model with
pm.Minibatch
on the data.This PR adds a model transformation that does that for the user. It's the reverse of the
remove_minibatched_nodes
transformer that @zaxtax implemented recently.This is a WIP, it doesn't actually work now, because I can't figure out how to rebuild the observed variable with the
total_size
set correctly. Help wanted.Related Issue
Checklist
Type of change
📚 Documentation preview 📚: https://pymc--7785.org.readthedocs.build/en/7785/